Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cdc1fa12
"tests/vscode:/vscode.git/clone" did not exist on "42f5e7c52a5852e20937001332572c8cb8115af0"
Unverified
Commit
cdc1fa12
authored
Feb 25, 2025
by
Harry Mellor
Committed by
GitHub
Feb 24, 2025
Browse files
Remove unused kwargs from model definitions (#13555)
parent
f61528d4
Changes
104
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
69 additions
and
288 deletions
+69
-288
vllm/model_executor/models/qwen2_rm.py
vllm/model_executor/models/qwen2_rm.py
+2
-6
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+2
-7
vllm/model_executor/models/qwen_vl.py
vllm/model_executor/models/qwen_vl.py
+2
-6
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+1
-6
vllm/model_executor/models/solar.py
vllm/model_executor/models/solar.py
+4
-17
vllm/model_executor/models/stablelm.py
vllm/model_executor/models/stablelm.py
+6
-23
vllm/model_executor/models/starcoder2.py
vllm/model_executor/models/starcoder2.py
+6
-20
vllm/model_executor/models/transformers.py
vllm/model_executor/models/transformers.py
+2
-11
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+5
-12
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+13
-76
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+0
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-21
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+5
-14
vllm/worker/cpu_enc_dec_model_runner.py
vllm/worker/cpu_enc_dec_model_runner.py
+0
-4
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+0
-2
vllm/worker/cpu_pooling_model_runner.py
vllm/worker/cpu_pooling_model_runner.py
+0
-14
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+1
-13
vllm/worker/hpu_model_runner.py
vllm/worker/hpu_model_runner.py
+15
-21
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+2
-11
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+2
-2
No files found.
vllm/model_executor/models/qwen2_rm.py
View file @
cdc1fa12
...
@@ -5,12 +5,11 @@
...
@@ -5,12 +5,11 @@
# Copyright 2024 The Qwen team.
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2023 The vLLM team.
"""Inference-only Qwen2-RM model compatible with HuggingFace weights."""
"""Inference-only Qwen2-RM model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
...
@@ -80,13 +79,10 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -80,13 +79,10 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
logits
,
_
=
self
.
score
(
hidden_states
)
logits
,
_
=
self
.
score
(
hidden_states
)
return
logits
return
logits
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
cdc1fa12
...
@@ -24,8 +24,8 @@
...
@@ -24,8 +24,8 @@
# limitations under the License.
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from
functools
import
cached_property
,
partial
from
functools
import
cached_property
,
partial
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
from
typing
import
(
Any
,
Callable
,
Iterable
,
Literal
,
Mapping
,
Optional
,
Set
,
Set
,
Tuple
,
Type
,
TypedDict
,
Union
)
Tuple
,
Type
,
TypedDict
,
Union
)
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -38,7 +38,6 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
...
@@ -38,7 +38,6 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
Qwen2VLConfig
,
Qwen2VLVisionConfig
)
Qwen2VLConfig
,
Qwen2VLVisionConfig
)
from
transformers.models.qwen2_vl.image_processing_qwen2_vl
import
smart_resize
from
transformers.models.qwen2_vl.image_processing_qwen2_vl
import
smart_resize
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
parallel_state
,
tensor_model_parallel_all_gather
from
vllm.distributed
import
parallel_state
,
tensor_model_parallel_all_gather
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.distributed
import
utils
as
dist_utils
...
@@ -1302,8 +1301,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1302,8 +1301,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
@@ -1354,8 +1351,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1354,8 +1351,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
language_model
.
model
(
hidden_states
=
self
.
language_model
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
)
)
...
...
vllm/model_executor/models/qwen_vl.py
View file @
cdc1fa12
...
@@ -22,7 +22,6 @@ from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer,
...
@@ -22,7 +22,6 @@ from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer,
from
transformers.image_utils
import
ImageInput
from
transformers.image_utils
import
ImageInput
from
transformers.tokenization_utils_base
import
TextInput
from
transformers.tokenization_utils_base
import
TextInput
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
@@ -766,8 +765,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
...
@@ -766,8 +765,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
@@ -783,7 +780,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
...
@@ -783,7 +780,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
vision_embeddings
)
vision_embeddings
)
input_ids
=
None
input_ids
=
None
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
vllm/model_executor/models/roberta.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
itertools
import
itertools
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
RobertaConfig
from
transformers
import
RobertaConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.pooler
import
CrossEncodingPooler
from
vllm.model_executor.layers.pooler
import
CrossEncodingPooler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -243,16 +242,12 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -243,16 +242,12 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
self
,
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
roberta
(
input_ids
=
input_ids
,
return
self
.
roberta
(
input_ids
=
input_ids
,
position_ids
=
positions
,
position_ids
=
positions
,
kv_caches
=
kv_caches
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
attn_metadata
=
attn_metadata
,
token_type_ids
=
token_type_ids
)
token_type_ids
=
token_type_ids
)
vllm/model_executor/models/solar.py
View file @
cdc1fa12
...
@@ -23,13 +23,13 @@
...
@@ -23,13 +23,13 @@
# limitations under the License.
# limitations under the License.
"""Inference-only Solar model compatible with HuggingFace weights."""
"""Inference-only Solar model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -172,13 +172,11 @@ class SolarAttention(nn.Module):
...
@@ -172,13 +172,11 @@ class SolarAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -238,8 +236,6 @@ class SolarDecoderLayer(nn.Module):
...
@@ -238,8 +236,6 @@ class SolarDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Self Attention
...
@@ -252,8 +248,6 @@ class SolarDecoderLayer(nn.Module):
...
@@ -252,8 +248,6 @@ class SolarDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
# Fully Connected
# Fully Connected
...
@@ -315,8 +309,6 @@ class SolarModel(nn.Module):
...
@@ -315,8 +309,6 @@ class SolarModel(nn.Module):
self
,
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -357,8 +349,6 @@ class SolarModel(nn.Module):
...
@@ -357,8 +349,6 @@ class SolarModel(nn.Module):
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
residual
,
)
)
...
@@ -438,13 +428,10 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -438,13 +428,10 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
model_output
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
model_output
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
model_output
return
model_output
...
...
vllm/model_executor/models/stablelm.py
View file @
cdc1fa12
...
@@ -20,13 +20,13 @@
...
@@ -20,13 +20,13 @@
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
StableLmConfig
from
transformers
import
StableLmConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
...
@@ -147,13 +147,11 @@ class StablelmAttention(nn.Module):
...
@@ -147,13 +147,11 @@ class StablelmAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -183,8 +181,6 @@ class StablelmDecoderLayer(nn.Module):
...
@@ -183,8 +181,6 @@ class StablelmDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Self Attention
residual
=
hidden_states
residual
=
hidden_states
...
@@ -192,8 +188,6 @@ class StablelmDecoderLayer(nn.Module):
...
@@ -192,8 +188,6 @@ class StablelmDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -241,8 +235,6 @@ class StableLMEpochModel(nn.Module):
...
@@ -241,8 +235,6 @@ class StableLMEpochModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -254,14 +246,8 @@ class StableLMEpochModel(nn.Module):
...
@@ -254,14 +246,8 @@ class StableLMEpochModel(nn.Module):
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
norm
(
hidden_states
)
hidden_states
=
self
.
norm
(
hidden_states
)
...
@@ -296,13 +282,10 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
...
@@ -296,13 +282,10 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/starcoder2.py
View file @
cdc1fa12
...
@@ -19,13 +19,13 @@
...
@@ -19,13 +19,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
""" PyTorch Starcoder2 model."""
""" PyTorch Starcoder2 model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
Starcoder2Config
from
transformers
import
Starcoder2Config
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -118,13 +118,11 @@ class Starcoder2Attention(nn.Module):
...
@@ -118,13 +118,11 @@ class Starcoder2Attention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -184,8 +182,6 @@ class Starcoder2DecoderLayer(nn.Module):
...
@@ -184,8 +182,6 @@ class Starcoder2DecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Self Attention
# Self Attention
residual
=
hidden_states
residual
=
hidden_states
...
@@ -193,8 +189,6 @@ class Starcoder2DecoderLayer(nn.Module):
...
@@ -193,8 +189,6 @@ class Starcoder2DecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -246,8 +240,6 @@ class Starcoder2Model(nn.Module):
...
@@ -246,8 +240,6 @@ class Starcoder2Model(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -259,11 +251,8 @@ class Starcoder2Model(nn.Module):
...
@@ -259,11 +251,8 @@ class Starcoder2Model(nn.Module):
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
positions
,
hidden_states
)
hidden_states
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
norm
(
hidden_states
)
hidden_states
=
self
.
norm
(
hidden_states
)
...
@@ -306,13 +295,10 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
...
@@ -306,13 +295,10 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/transformers.py
View file @
cdc1fa12
...
@@ -22,7 +22,7 @@ from torch import nn
...
@@ -22,7 +22,7 @@ from torch import nn
from
transformers
import
AutoModel
,
PreTrainedModel
from
transformers
import
AutoModel
,
PreTrainedModel
from
transformers.modeling_utils
import
ALL_ATTENTION_FUNCTIONS
from
transformers.modeling_utils
import
ALL_ATTENTION_FUNCTIONS
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.utils
import
divide
from
vllm.distributed.utils
import
divide
...
@@ -59,7 +59,6 @@ def vllm_flash_attention_forward(
...
@@ -59,7 +59,6 @@ def vllm_flash_attention_forward(
# Transformers kwargs
# Transformers kwargs
scaling
:
Optional
[
float
]
=
None
,
scaling
:
Optional
[
float
]
=
None
,
# vLLM kwargs
# vLLM kwargs
attn_metadata
:
Optional
[
AttentionMetadata
]
=
None
,
attention_instances
:
Optional
[
list
[
Attention
]]
=
None
,
attention_instances
:
Optional
[
list
[
Attention
]]
=
None
,
**
kwargs
):
**
kwargs
):
self_attn
=
attention_instances
[
module
.
layer_idx
]
self_attn
=
attention_instances
[
module
.
layer_idx
]
...
@@ -68,12 +67,7 @@ def vllm_flash_attention_forward(
...
@@ -68,12 +67,7 @@ def vllm_flash_attention_forward(
hidden
=
query
.
shape
[
-
2
]
hidden
=
query
.
shape
[
-
2
]
query
,
key
,
value
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query
,
key
,
value
))
query
,
key
,
value
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query
,
key
,
value
))
query
,
key
,
value
=
(
x
.
reshape
(
hidden
,
-
1
)
for
x
in
(
query
,
key
,
value
))
query
,
key
,
value
=
(
x
.
reshape
(
hidden
,
-
1
)
for
x
in
(
query
,
key
,
value
))
return
self_attn
.
forward
(
return
self_attn
.
forward
(
query
,
key
,
value
),
None
query
,
key
,
value
,
kv_cache
=
None
,
# argument not used
attn_metadata
=
attn_metadata
),
None
ALL_ATTENTION_FUNCTIONS
[
"vllm"
]
=
vllm_flash_attention_forward
ALL_ATTENTION_FUNCTIONS
[
"vllm"
]
=
vllm_flash_attention_forward
...
@@ -251,8 +245,6 @@ class TransformersModel(nn.Module, SupportsQuant):
...
@@ -251,8 +245,6 @@ class TransformersModel(nn.Module, SupportsQuant):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
list
[
torch
.
Tensor
],
# argument not used
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -260,7 +252,6 @@ class TransformersModel(nn.Module, SupportsQuant):
...
@@ -260,7 +252,6 @@ class TransformersModel(nn.Module, SupportsQuant):
input_ids
[
None
,
...],
input_ids
[
None
,
...],
use_cache
=
False
,
use_cache
=
False
,
position_ids
=
positions
[
None
,
...],
position_ids
=
positions
[
None
,
...],
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
attention_instances
=
self
.
attention_instances
,
attention_instances
=
self
.
attention_instances
,
return_dict
=
False
)[
0
][
0
,
...]
# we remove batch dimension for now
return_dict
=
False
)[
0
][
0
,
...]
# we remove batch dimension for now
...
...
vllm/model_executor/models/ultravox.py
View file @
cdc1fa12
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
"""PyTorch Ultravox model."""
"""PyTorch Ultravox model."""
import
math
import
math
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
(
Any
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
from
typing
import
(
Any
,
Iterable
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
Tuple
,
TypedDict
,
Union
)
TypedDict
,
Union
)
import
torch
import
torch
import
torch.utils.checkpoint
import
torch.utils.checkpoint
...
@@ -16,8 +16,8 @@ from transformers.models.whisper import WhisperFeatureExtractor
...
@@ -16,8 +16,8 @@ from transformers.models.whisper import WhisperFeatureExtractor
from
transformers.models.whisper.modeling_whisper
import
WhisperEncoder
from
transformers.models.whisper.modeling_whisper
import
WhisperEncoder
from
vllm
import
envs
from
vllm
import
envs
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.activation
import
MulAndSilu
,
get_act_fn
from
vllm.model_executor.layers.activation
import
MulAndSilu
,
get_act_fn
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
...
@@ -495,13 +495,13 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
...
@@ -495,13 +495,13 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
multimodal_embeddings
:
Optional
[
NestedTensors
]
=
None
,
multimodal_embeddings
:
Optional
[
NestedTensors
]
=
None
,
attn_metadata
:
Optional
[
AttentionMetadata
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
:
# TODO(ywang96): remove this block after v0 is deprecated.
# TODO(ywang96): remove this block after v0 is deprecated.
if
not
envs
.
VLLM_USE_V1
:
if
not
envs
.
VLLM_USE_V1
:
attn_metadata
=
get_forward_context
().
attn_metadata
merge_multimodal_embeddings_from_map
(
merge_multimodal_embeddings_from_map
(
inputs_embeds
,
multimodal_embeddings
,
inputs_embeds
,
multimodal_embeddings
,
attn_metadata
.
multi_modal_placeholder_index_maps
[
"audio"
])
attn_metadata
.
multi_modal_placeholder_index_maps
[
"audio"
])
...
@@ -514,8 +514,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
...
@@ -514,8 +514,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def
forward
(
self
,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
torch
.
Tensor
]
=
None
,
intermediate_tensors
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
**
kwargs
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -540,17 +538,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
...
@@ -540,17 +538,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
elif
inputs_embeds
is
None
:
elif
inputs_embeds
is
None
:
multimodal_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
multimodal_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
# TODO(ywang96): remove attn_metadata from get_input_embeddings
# after v0 is deprecated
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
multimodal_embeddings
,
multimodal_embeddings
)
attn_metadata
)
input_ids
=
None
input_ids
=
None
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/whisper.py
View file @
cdc1fa12
...
@@ -10,7 +10,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
...
@@ -10,7 +10,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
WhisperProcessor
)
WhisperProcessor
)
from
transformers.models.whisper.modeling_whisper
import
sinusoids
from
transformers.models.whisper.modeling_whisper
import
sinusoids
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -134,13 +134,11 @@ class WhisperAttention(nn.Module):
...
@@ -134,13 +134,11 @@ class WhisperAttention(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
):
):
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
out_proj
(
attn_output
)
output
,
_
=
self
.
out_proj
(
attn_output
)
...
@@ -196,8 +194,6 @@ class WhisperCrossAttention(WhisperAttention):
...
@@ -196,8 +194,6 @@ class WhisperCrossAttention(WhisperAttention):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
encoder_hidden_states
:
Optional
[
torch
.
Tensor
],
encoder_hidden_states
:
Optional
[
torch
.
Tensor
],
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
):
):
q
,
_
=
self
.
q_proj
(
hidden_states
)
q
,
_
=
self
.
q_proj
(
hidden_states
)
...
@@ -209,13 +205,7 @@ class WhisperCrossAttention(WhisperAttention):
...
@@ -209,13 +205,7 @@ class WhisperCrossAttention(WhisperAttention):
else
:
else
:
k
=
v
=
None
k
=
v
=
None
attn_output
=
self
.
attn
(
attn_output
=
self
.
attn
(
q
,
k
,
v
)
q
,
k
,
v
,
kv_cache
,
attn_metadata
,
)
output
,
_
=
self
.
out_proj
(
attn_output
)
output
,
_
=
self
.
out_proj
(
attn_output
)
...
@@ -285,16 +275,10 @@ class WhisperEncoderLayer(nn.Module):
...
@@ -285,16 +275,10 @@ class WhisperEncoderLayer(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
):
):
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
hidden_states
=
hidden_states
)
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
final_layer_norm
(
hidden_states
)
hidden_states
=
self
.
final_layer_norm
(
hidden_states
)
...
@@ -348,14 +332,10 @@ class WhisperDecoderLayer(nn.Module):
...
@@ -348,14 +332,10 @@ class WhisperDecoderLayer(nn.Module):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
encoder_hidden_states
:
Optional
[
torch
.
Tensor
],
encoder_hidden_states
:
Optional
[
torch
.
Tensor
],
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
):
):
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
hidden_states
=
hidden_states
,
hidden_states
=
self
.
self_attn
(
hidden_states
=
hidden_states
)
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
residual
=
hidden_states
residual
=
hidden_states
...
@@ -363,8 +343,6 @@ class WhisperDecoderLayer(nn.Module):
...
@@ -363,8 +343,6 @@ class WhisperDecoderLayer(nn.Module):
hidden_states
=
self
.
encoder_attn
(
hidden_states
=
self
.
encoder_attn
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -411,12 +389,7 @@ class WhisperEncoder(nn.Module):
...
@@ -411,12 +389,7 @@ class WhisperEncoder(nn.Module):
self
.
embed_positions
.
weight
.
copy_
(
self
.
embed_positions
.
weight
.
copy_
(
sinusoids
(
*
self
.
embed_positions
.
weight
.
shape
))
sinusoids
(
*
self
.
embed_positions
.
weight
.
shape
))
def
forward
(
def
forward
(
self
,
input_features
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]):
self
,
input_features
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
):
hidden_states
=
[]
hidden_states
=
[]
for
features
in
input_features
:
for
features
in
input_features
:
embeds
=
nn
.
functional
.
gelu
(
self
.
conv1
(
features
))
embeds
=
nn
.
functional
.
gelu
(
self
.
conv1
(
features
))
...
@@ -426,12 +399,8 @@ class WhisperEncoder(nn.Module):
...
@@ -426,12 +399,8 @@ class WhisperEncoder(nn.Module):
hidden_states
.
append
(
embeds
)
hidden_states
.
append
(
embeds
)
hidden_states
=
torch
.
cat
(
hidden_states
)
hidden_states
=
torch
.
cat
(
hidden_states
)
for
idx
,
encoder_layer
in
enumerate
(
self
.
layers
):
for
encoder_layer
in
self
.
layers
:
hidden_states
=
encoder_layer
(
hidden_states
=
encoder_layer
(
hidden_states
)
hidden_states
,
kv_cache
=
kv_caches
[
idx
],
attn_metadata
=
attn_metadata
,
)
hidden_states
=
self
.
layer_norm
(
hidden_states
)
hidden_states
=
self
.
layer_norm
(
hidden_states
)
return
hidden_states
return
hidden_states
...
@@ -466,19 +435,15 @@ class WhisperDecoder(nn.Module):
...
@@ -466,19 +435,15 @@ class WhisperDecoder(nn.Module):
input_ids
,
input_ids
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
encoder_hidden_states
:
Optional
[
torch
.
Tensor
],
encoder_hidden_states
:
Optional
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
):
):
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
)
positions
=
self
.
embed_positions
(
positions
)
positions
=
self
.
embed_positions
(
positions
)
hidden_states
=
inputs_embeds
+
positions
hidden_states
=
inputs_embeds
+
positions
for
idx
,
decoder_layer
in
enumerate
(
self
.
layers
)
:
for
decoder_layer
in
self
.
layers
:
hidden_states
=
decoder_layer
(
hidden_states
=
decoder_layer
(
hidden_states
,
hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
kv_cache
=
kv_caches
[
idx
],
attn_metadata
=
attn_metadata
,
)
)
hidden_states
=
self
.
layer_norm
(
hidden_states
)
hidden_states
=
self
.
layer_norm
(
hidden_states
)
...
@@ -505,36 +470,22 @@ class WhisperModel(nn.Module):
...
@@ -505,36 +470,22 @@ class WhisperModel(nn.Module):
input_features
:
Optional
[
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]],
input_features
:
Optional
[
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]],
input_ids
:
Optional
[
torch
.
Tensor
],
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
encoder_outputs
=
self
.
get_encoder_outputs
(
encoder_outputs
=
self
.
get_encoder_outputs
(
input_features
)
input_features
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
)
decoder_outputs
=
self
.
decoder
(
decoder_outputs
=
self
.
decoder
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
encoder_hidden_states
=
encoder_outputs
,
encoder_hidden_states
=
encoder_outputs
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
)
)
return
decoder_outputs
return
decoder_outputs
def
get_encoder_outputs
(
def
get_encoder_outputs
(
self
,
self
,
input_features
:
Optional
[
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]],
input_features
:
Optional
[
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
)
->
Optional
[
torch
.
Tensor
]:
if
input_features
is
None
:
if
input_features
is
None
:
return
None
return
None
return
self
.
encoder
(
return
self
.
encoder
(
input_features
)
input_features
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
...
@@ -733,8 +684,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
...
@@ -733,8 +684,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
**
kwargs
,
**
kwargs
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
...
@@ -742,31 +691,19 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
...
@@ -742,31 +691,19 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
input_features
=
audio_input
[
"input_features"
],
input_features
=
audio_input
[
"input_features"
],
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
)
)
return
decoder_outputs
return
decoder_outputs
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
self
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
**
kwargs
,
)
->
Optional
[
NestedTensors
]:
# TODO: This method does not obey the interface for SupportsMultiModal.
# TODO: This method does not obey the interface for SupportsMultiModal.
# Refactor this once encoder/decoder support is implemented in V1.
# Refactor this once encoder/decoder support is implemented in V1.
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
return
self
.
model
.
get_encoder_outputs
(
return
self
.
model
.
get_encoder_outputs
(
audio_input
[
"input_features"
])
audio_input
[
"input_features"
],
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
)
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
multimodal_embeddings
:
Optional
[
NestedTensors
]
=
None
,
multimodal_embeddings
:
Optional
[
NestedTensors
]
=
None
,
attn_metadata
:
Optional
[
AttentionMetadata
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# TODO: This method just returns the decoder sequence embeddings since
# TODO: This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens. Refactor this once
# Whisper does not have encoder text tokens. Refactor this once
...
...
vllm/spec_decode/draft_model_runner.py
View file @
cdc1fa12
...
@@ -288,8 +288,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
...
@@ -288,8 +288,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
hidden_states
=
model_executable
(
hidden_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
model_input
.
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
device
=
self
.
device
),
device
=
self
.
device
),
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
cdc1fa12
...
@@ -939,8 +939,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -939,8 +939,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
kv_caches
=
self
.
kv_caches
,
attn_metadata
=
None
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
)
)
...
@@ -1137,11 +1135,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1137,11 +1135,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
_dummy_run
(
def
_dummy_run
(
self
,
self
,
num_tokens
:
int
,
num_tokens
:
int
,
kv_caches
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
model
=
self
.
model
model
=
self
.
model
if
kv_caches
is
None
:
kv_caches
=
self
.
kv_caches
if
self
.
is_multimodal_model
:
if
self
.
is_multimodal_model
:
input_ids
=
None
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
[:
num_tokens
]
inputs_embeds
=
self
.
inputs_embeds
[:
num_tokens
]
...
@@ -1172,26 +1167,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1172,26 +1167,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states
=
model
(
hidden_states
=
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
None
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
)
)
return
hidden_states
return
hidden_states
def
profile_run
(
self
)
->
None
:
def
profile_run
(
self
)
->
None
:
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value `None`.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
# it is important to create tensors inside the loop, rather than
# multiplying the list, to avoid Dynamo from treating them as
# tensor aliasing.
dummy_kv_caches
=
[
torch
.
tensor
((),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
for
_
in
range
(
self
.
num_attn_layers
)
]
# Profile with multimodal encoder & encoder cache.
# Profile with multimodal encoder & encoder cache.
# TODO: handle encoder-decoder models once we support them.
# TODO: handle encoder-decoder models once we support them.
if
(
self
.
is_multimodal_model
and
self
.
max_num_encoder_input_tokens
>
0
if
(
self
.
is_multimodal_model
and
self
.
max_num_encoder_input_tokens
>
0
...
@@ -1302,8 +1283,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1302,8 +1283,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with
self
.
maybe_profile_with_lora
(
self
.
lora_config
,
with
self
.
maybe_profile_with_lora
(
self
.
lora_config
,
num_scheduled_tokens
):
num_scheduled_tokens
):
# Trigger compilation for general shape.
# Trigger compilation for general shape.
hidden_states
=
self
.
_dummy_run
(
self
.
max_num_tokens
,
hidden_states
=
self
.
_dummy_run
(
self
.
max_num_tokens
)
dummy_kv_caches
)
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
hidden_states
=
hidden_states
[
logit_indices
]
hidden_states
=
hidden_states
[
logit_indices
]
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
cdc1fa12
...
@@ -13,11 +13,10 @@ import torch.nn as nn
...
@@ -13,11 +13,10 @@ import torch.nn as nn
import
torch_xla.core.xla_model
as
xm
import
torch_xla.core.xla_model
as
xm
import
torch_xla.runtime
as
xr
import
torch_xla.runtime
as
xr
from
vllm.attention
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
...
@@ -623,7 +622,6 @@ class TPUModelRunner:
...
@@ -623,7 +622,6 @@ class TPUModelRunner:
assert
self
.
model
is
not
None
assert
self
.
model
is
not
None
selected_token_ids
=
self
.
model
(
prompt_data
.
input_tokens
,
selected_token_ids
=
self
.
model
(
prompt_data
.
input_tokens
,
prompt_data
.
input_positions
,
prompt_data
.
input_positions
,
prompt_data
.
attn_metadata
,
self
.
kv_caches
)
self
.
kv_caches
)
# In parallel to TPU execution, prepare the next iteration
# In parallel to TPU execution, prepare the next iteration
...
@@ -662,7 +660,6 @@ class TPUModelRunner:
...
@@ -662,7 +660,6 @@ class TPUModelRunner:
assert
self
.
model
is
not
None
assert
self
.
model
is
not
None
selected_token_ids
=
self
.
model
(
decode_data
.
input_tokens
,
selected_token_ids
=
self
.
model
(
decode_data
.
input_tokens
,
decode_data
.
input_positions
,
decode_data
.
input_positions
,
decode_data
.
attn_metadata
,
self
.
kv_caches
)
self
.
kv_caches
)
# Transfer sampled tokens from TPU to CPU
# Transfer sampled tokens from TPU to CPU
...
@@ -839,7 +836,7 @@ class TPUModelRunner:
...
@@ -839,7 +836,7 @@ class TPUModelRunner:
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
0
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
0
):
assert
self
.
model
is
not
None
assert
self
.
model
is
not
None
self
.
model
(
token_ids
,
position_ids
,
attn_metadata
,
kv_caches
)
self
.
model
(
token_ids
,
position_ids
,
kv_caches
)
def
capture_model
(
self
)
->
None
:
def
capture_model
(
self
)
->
None
:
"""Compile the model."""
"""Compile the model."""
...
@@ -963,7 +960,6 @@ class ModelWrapperV1(nn.Module):
...
@@ -963,7 +960,6 @@ class ModelWrapperV1(nn.Module):
self
,
self
,
token_ids
:
torch
.
Tensor
,
token_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Executes the forward pass of the model and samples the next token.
"""Executes the forward pass of the model and samples the next token.
...
@@ -971,7 +967,6 @@ class ModelWrapperV1(nn.Module):
...
@@ -971,7 +967,6 @@ class ModelWrapperV1(nn.Module):
Args:
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size].
input_lens: The actual input lengths of shape [batch_size].
t: The sampling temperature of shape [batch_size].
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
p: The top-p probability of shape [batch_size].
...
@@ -980,7 +975,8 @@ class ModelWrapperV1(nn.Module):
...
@@ -980,7 +975,8 @@ class ModelWrapperV1(nn.Module):
memory profiling at initialization.
memory profiling at initialization.
"""
"""
# Skip this in memory profiling at initialization.
# Skip this in memory profiling at initialization.
if
attn_metadata
is
not
None
and
kv_caches
[
0
][
0
].
numel
()
>
0
:
if
kv_caches
[
0
][
0
].
numel
()
>
0
:
attn_metadata
=
get_forward_context
().
attn_metadata
# index_copy_(slot_mapping) only works when the inserted dimension
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
...
@@ -1001,12 +997,7 @@ class ModelWrapperV1(nn.Module):
...
@@ -1001,12 +997,7 @@ class ModelWrapperV1(nn.Module):
attn_metadata
.
slot_mapping
=
slot_mapping
attn_metadata
.
slot_mapping
=
slot_mapping
assert
self
.
model
is
not
None
assert
self
.
model
is
not
None
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
token_ids
,
position_ids
)
token_ids
,
position_ids
,
kv_caches
,
attn_metadata
,
)
hidden_states
=
hidden_states
.
flatten
(
0
,
1
)
hidden_states
=
hidden_states
.
flatten
(
0
,
1
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
...
...
vllm/worker/cpu_enc_dec_model_runner.py
View file @
cdc1fa12
...
@@ -297,10 +297,6 @@ class CPUEncoderDecoderModelRunner(
...
@@ -297,10 +297,6 @@ class CPUEncoderDecoderModelRunner(
model_input
.
encoder_input_tokens
,
model_input
.
encoder_input_tokens
,
"encoder_positions"
:
"encoder_positions"
:
model_input
.
encoder_input_positions
,
model_input
.
encoder_input_positions
,
"kv_caches"
:
kv_caches
,
"attn_metadata"
:
model_input
.
attn_metadata
,
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
device
=
self
.
device
),
device
=
self
.
device
),
"intermediate_tensors"
:
"intermediate_tensors"
:
...
...
vllm/worker/cpu_model_runner.py
View file @
cdc1fa12
...
@@ -654,8 +654,6 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
...
@@ -654,8 +654,6 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
hidden_states
=
model_executable
(
hidden_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
model_input
.
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
**
execute_model_kwargs
,
**
execute_model_kwargs
,
**
multimodal_kwargs
,
**
multimodal_kwargs
,
...
...
vllm/worker/cpu_pooling_model_runner.py
View file @
cdc1fa12
...
@@ -41,16 +41,6 @@ class CPUPoolingModelRunner(
...
@@ -41,16 +41,6 @@ class CPUPoolingModelRunner(
raise
ValueError
(
raise
ValueError
(
"CPU worker does not support multi-step execution."
)
"CPU worker does not support multi-step execution."
)
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches
=
[
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
for
_
in
range
(
num_layers
)
]
model_executable
=
self
.
model
model_executable
=
self
.
model
cross_enc_kwargs
=
{}
cross_enc_kwargs
=
{}
if
model_input
.
token_type_ids
is
not
None
:
if
model_input
.
token_type_ids
is
not
None
:
...
@@ -60,10 +50,6 @@ class CPUPoolingModelRunner(
...
@@ -60,10 +50,6 @@ class CPUPoolingModelRunner(
model_input
.
input_tokens
,
model_input
.
input_tokens
,
"positions"
:
"positions"
:
model_input
.
input_positions
,
model_input
.
input_positions
,
"kv_caches"
:
kv_caches
,
"attn_metadata"
:
model_input
.
attn_metadata
,
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
device
=
self
.
device
),
device
=
self
.
device
),
**
cross_enc_kwargs
,
**
cross_enc_kwargs
,
...
...
vllm/worker/enc_dec_model_runner.py
View file @
cdc1fa12
...
@@ -184,8 +184,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -184,8 +184,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
encoder_input_ids
=
model_input
.
encoder_input_tokens
,
encoder_input_ids
=
model_input
.
encoder_input_tokens
,
encoder_positions
=
model_input
.
encoder_input_positions
,
encoder_positions
=
model_input
.
encoder_input_positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
model_input
.
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
device
=
self
.
device
),
device
=
self
.
device
),
...
@@ -324,21 +322,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -324,21 +322,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
or
encoder_dummy_data
.
multi_modal_placeholders
)
or
encoder_dummy_data
.
multi_modal_placeholders
)
seqs
.
append
(
seq
)
seqs
.
append
(
seq
)
# Run the model with the dummy inputs.
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches
=
[
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
for
_
in
range
(
num_layers
)
]
finished_requests_ids
=
[
seq
.
request_id
for
seq
in
seqs
]
finished_requests_ids
=
[
seq
.
request_id
for
seq
in
seqs
]
model_input
=
self
.
prepare_model_input
(
model_input
=
self
.
prepare_model_input
(
seqs
,
finished_requests_ids
=
finished_requests_ids
)
seqs
,
finished_requests_ids
=
finished_requests_ids
)
intermediate_tensors
=
None
intermediate_tensors
=
None
self
.
execute_model
(
model_input
,
kv_caches
,
intermediate_tensors
)
self
.
execute_model
(
model_input
,
None
,
intermediate_tensors
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
return
return
...
...
vllm/worker/hpu_model_runner.py
View file @
cdc1fa12
...
@@ -384,11 +384,12 @@ class HpuModelAdapter:
...
@@ -384,11 +384,12 @@ class HpuModelAdapter:
if
'virtual_engine'
in
kwargs
:
if
'virtual_engine'
in
kwargs
:
virtual_engine
=
kwargs
.
pop
(
'virtual_engine'
)
virtual_engine
=
kwargs
.
pop
(
'virtual_engine'
)
input_ids
=
kwargs
[
'input_ids'
]
input_ids
=
kwargs
[
'input_ids'
]
kwargs
[
'attn_metadata'
]
=
self
.
_update_metadata
(
attn_metadata
=
self
.
_update_metadata
(
kwargs
.
pop
(
'attn_metadata'
),
kwargs
[
'attn_metadata'
],
input_ids
.
size
(
0
),
input_ids
.
size
(
1
),
input_ids
.
size
(
0
),
input_ids
.
device
,
self
.
dtype
)
input_ids
.
size
(
1
),
input_ids
.
device
,
self
.
dtype
)
LoraMask
.
setLoraMask
(
kwargs
.
pop
(
'lora_mask'
))
LoraMask
.
setLoraMask
(
kwargs
.
pop
(
'lora_mask'
))
with
set_forward_context
(
kwargs
[
'
attn_metadata
'
]
,
self
.
vllm_config
,
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
virtual_engine
):
virtual_engine
):
hidden_states
=
self
.
model
(
*
args
,
**
kwargs
)
hidden_states
=
self
.
model
(
*
args
,
**
kwargs
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_states
.
shape
[
-
1
])
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_states
.
shape
[
-
1
])
...
@@ -1346,15 +1347,13 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1346,15 +1347,13 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
max_seq_len
=
self
.
bucketing_global_state
.
prompt_seq_bucket_cfg
[
-
1
]
max_seq_len
=
self
.
bucketing_global_state
.
prompt_seq_bucket_cfg
[
-
1
]
max_batch_size
=
min
(
self
.
max_num_batched_tokens
//
max_seq_len
,
max_batch_size
=
min
(
self
.
max_num_batched_tokens
//
max_seq_len
,
self
.
scheduler_config
.
max_num_seqs
)
self
.
scheduler_config
.
max_num_seqs
)
self
.
warmup_scenario
(
max_batch_size
,
max_seq_len
,
True
,
kv_caches
,
self
.
warmup_scenario
(
max_batch_size
,
max_seq_len
,
True
,
False
,
True
)
False
,
True
)
return
return
def
warmup_scenario
(
self
,
def
warmup_scenario
(
self
,
batch_size
,
batch_size
,
seq_len
,
seq_len
,
is_prompt
,
is_prompt
,
kv_caches
,
is_pt_profiler_run
=
False
,
is_pt_profiler_run
=
False
,
is_lora_profile_run
=
False
)
->
None
:
is_lora_profile_run
=
False
)
->
None
:
use_graphs
=
self
.
_use_graphs
(
batch_size
,
seq_len
,
is_prompt
)
use_graphs
=
self
.
_use_graphs
(
batch_size
,
seq_len
,
is_prompt
)
...
@@ -1418,7 +1417,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1418,7 +1417,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
profiler
.
start
()
profiler
.
start
()
for
_
in
range
(
times
):
for
_
in
range
(
times
):
inputs
=
self
.
prepare_model_input
(
seqs
)
inputs
=
self
.
prepare_model_input
(
seqs
)
self
.
execute_model
(
inputs
,
kv_caches
,
warmup_mode
=
True
)
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
)
torch
.
hpu
.
synchronize
()
torch
.
hpu
.
synchronize
()
if
profiler
:
if
profiler
:
profiler
.
step
()
profiler
.
step
()
...
@@ -1470,17 +1469,16 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1470,17 +1469,16 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
f
"free_mem:
{
free_mem
}
"
)
f
"free_mem:
{
free_mem
}
"
)
logger
.
info
(
msg
)
logger
.
info
(
msg
)
def
warmup_all_buckets
(
self
,
buckets
,
is_prompt
,
kv_caches
):
def
warmup_all_buckets
(
self
,
buckets
,
is_prompt
):
for
i
,
(
batch_size
,
seq_len
)
in
enumerate
(
reversed
(
buckets
)):
for
i
,
(
batch_size
,
seq_len
)
in
enumerate
(
reversed
(
buckets
)):
self
.
log_warmup
(
'Prompt'
if
is_prompt
else
'Decode'
,
i
,
self
.
log_warmup
(
'Prompt'
if
is_prompt
else
'Decode'
,
i
,
len
(
buckets
),
batch_size
,
seq_len
)
len
(
buckets
),
batch_size
,
seq_len
)
self
.
warmup_scenario
(
batch_size
,
seq_len
,
is_prompt
,
kv_caches
)
self
.
warmup_scenario
(
batch_size
,
seq_len
,
is_prompt
)
def
warmup_graphs
(
self
,
def
warmup_graphs
(
self
,
strategy
,
strategy
,
buckets
,
buckets
,
is_prompt
,
is_prompt
,
kv_caches
,
available_mem
,
available_mem
,
starting_mem
=
0
,
starting_mem
=
0
,
total_batch_seq
=
0.001
):
total_batch_seq
=
0.001
):
...
@@ -1512,7 +1510,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1512,7 +1510,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self
.
graphed_buckets
.
add
(
graphed_bucket
)
self
.
graphed_buckets
.
add
(
graphed_bucket
)
self
.
log_warmup
(
phase
,
idx
,
num_candidates
,
batch_size
,
seq_len
)
self
.
log_warmup
(
phase
,
idx
,
num_candidates
,
batch_size
,
seq_len
)
with
HabanaMemoryProfiler
()
as
mem_prof
:
with
HabanaMemoryProfiler
()
as
mem_prof
:
self
.
warmup_scenario
(
batch_size
,
seq_len
,
is_prompt
,
kv_caches
)
self
.
warmup_scenario
(
batch_size
,
seq_len
,
is_prompt
)
used_mem
=
align_workers
(
mem_prof
.
consumed_device_memory
,
used_mem
=
align_workers
(
mem_prof
.
consumed_device_memory
,
torch
.
distributed
.
ReduceOp
.
MAX
)
torch
.
distributed
.
ReduceOp
.
MAX
)
available_mem
-=
used_mem
available_mem
-=
used_mem
...
@@ -1542,8 +1540,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1542,8 +1540,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
graphs
=
graph
==
't'
graphs
=
graph
==
't'
if
graphs
:
if
graphs
:
self
.
graphed_buckets
.
add
((
int
(
bs
),
int
(
seq_len
),
is_prompt
))
self
.
graphed_buckets
.
add
((
int
(
bs
),
int
(
seq_len
),
is_prompt
))
self
.
warmup_scenario
(
int
(
bs
),
int
(
seq_len
),
is_prompt
,
kv_caches
,
self
.
warmup_scenario
(
int
(
bs
),
int
(
seq_len
),
is_prompt
,
True
)
True
)
raise
AssertionError
(
"Finished profiling"
)
raise
AssertionError
(
"Finished profiling"
)
if
self
.
skip_warmup
:
if
self
.
skip_warmup
:
logger
.
info
(
"Skipping warmup..."
)
logger
.
info
(
"Skipping warmup..."
)
...
@@ -1608,9 +1605,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1608,9 +1605,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
with
compile_only_mode_context
(
with
compile_only_mode_context
(
)
if
can_use_compile_only_mode
else
contextlib
.
nullcontext
():
)
if
can_use_compile_only_mode
else
contextlib
.
nullcontext
():
self
.
warmup_all_buckets
(
self
.
bucketing_global_state
.
prompt_buckets
,
self
.
warmup_all_buckets
(
self
.
bucketing_global_state
.
prompt_buckets
,
True
,
kv_caches
)
True
)
self
.
warmup_all_buckets
(
self
.
bucketing_global_state
.
decode_buckets
,
self
.
warmup_all_buckets
(
self
.
bucketing_global_state
.
decode_buckets
,
False
,
kv_caches
)
False
)
if
not
self
.
enforce_eager
and
htorch
.
utils
.
internal
.
is_lazy
():
if
not
self
.
enforce_eager
and
htorch
.
utils
.
internal
.
is_lazy
():
assert
self
.
mem_margin
is
not
None
,
\
assert
self
.
mem_margin
is
not
None
,
\
...
@@ -1641,11 +1638,11 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1641,11 +1638,11 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
mem_post_prompt
,
prompt_batch_seq
,
prompt_captured_all
=
\
mem_post_prompt
,
prompt_batch_seq
,
prompt_captured_all
=
\
self
.
warmup_graphs
(
self
.
warmup_graphs
(
prompt_strategy
,
self
.
bucketing_global_state
.
prompt_buckets
,
prompt_strategy
,
self
.
bucketing_global_state
.
prompt_buckets
,
True
,
kv_caches
,
prompt_available_memory
)
True
,
prompt_available_memory
)
mem_post_decode
,
decode_batch_seq
,
decode_captured_all
=
\
mem_post_decode
,
decode_batch_seq
,
decode_captured_all
=
\
self
.
warmup_graphs
(
self
.
warmup_graphs
(
decode_strategy
,
self
.
bucketing_global_state
.
decode_buckets
,
decode_strategy
,
self
.
bucketing_global_state
.
decode_buckets
,
False
,
kv_caches
,
decode_available_memory
)
False
,
decode_available_memory
)
# Not all prompt buckets were captured, but all decode buckets
# Not all prompt buckets were captured, but all decode buckets
# were captured and we have some free graph-allocated space
# were captured and we have some free graph-allocated space
...
@@ -1656,7 +1653,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1656,7 +1653,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self
.
warmup_graphs
(
self
.
warmup_graphs
(
prompt_strategy
,
prompt_strategy
,
self
.
bucketing_global_state
.
prompt_buckets
,
True
,
self
.
bucketing_global_state
.
prompt_buckets
,
True
,
kv_caches
,
graph_free_mem
-
mem_post_prompt
-
mem_post_decode
,
graph_free_mem
-
mem_post_prompt
-
mem_post_decode
,
mem_post_prompt
,
prompt_batch_seq
))
mem_post_prompt
,
prompt_batch_seq
))
...
@@ -1669,7 +1665,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1669,7 +1665,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
mem_post_decode
,
_
,
_
=
self
.
warmup_graphs
(
mem_post_decode
,
_
,
_
=
self
.
warmup_graphs
(
decode_strategy
,
decode_strategy
,
self
.
bucketing_global_state
.
decode_buckets
,
False
,
self
.
bucketing_global_state
.
decode_buckets
,
False
,
kv_caches
,
graph_free_mem
-
mem_post_prompt
-
mem_post_decode
,
graph_free_mem
-
mem_post_prompt
-
mem_post_decode
,
mem_post_decode
,
decode_batch_seq
)
mem_post_decode
,
decode_batch_seq
)
...
@@ -1982,7 +1977,6 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
...
@@ -1982,7 +1977,6 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
execute_model_kwargs
=
{
execute_model_kwargs
=
{
"input_ids"
:
input_tokens
,
"input_ids"
:
input_tokens
,
"positions"
:
input_positions
,
"positions"
:
input_positions
,
"kv_caches"
:
kv_caches
,
"attn_metadata"
:
self
.
trim_attn_metadata
(
attn_metadata
),
"attn_metadata"
:
self
.
trim_attn_metadata
(
attn_metadata
),
"intermediate_tensors"
:
intermediate_tensors
,
"intermediate_tensors"
:
intermediate_tensors
,
"lora_mask"
:
lora_mask
,
"lora_mask"
:
lora_mask
,
...
...
vllm/worker/model_runner.py
View file @
cdc1fa12
...
@@ -26,7 +26,7 @@ from vllm.core.scheduler import SchedulerOutputs
...
@@ -26,7 +26,7 @@ from vllm.core.scheduler import SchedulerOutputs
from
vllm.distributed
import
get_kv_transfer_group
,
get_pp_group
from
vllm.distributed
import
get_kv_transfer_group
,
get_pp_group
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
graph_capture
)
graph_capture
)
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.layers
import
LoRAMapping
...
@@ -1727,8 +1727,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1727,8 +1727,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
hidden_or_intermediate_states
=
model_executable
(
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
model_input
.
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
device
=
self
.
device
),
device
=
self
.
device
),
...
@@ -1913,8 +1911,6 @@ class CUDAGraphRunner(nn.Module):
...
@@ -1913,8 +1911,6 @@ class CUDAGraphRunner(nn.Module):
self
.
model
(
self
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_inputs
,
intermediate_tensors
=
intermediate_inputs
,
**
kwargs
,
**
kwargs
,
)
)
...
@@ -1927,8 +1923,6 @@ class CUDAGraphRunner(nn.Module):
...
@@ -1927,8 +1923,6 @@ class CUDAGraphRunner(nn.Module):
output_hidden_or_intermediate_states
=
self
.
model
(
output_hidden_or_intermediate_states
=
self
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_inputs
,
intermediate_tensors
=
intermediate_inputs
,
**
kwargs
,
**
kwargs
,
)
)
...
@@ -1976,13 +1970,10 @@ class CUDAGraphRunner(nn.Module):
...
@@ -1976,13 +1970,10 @@ class CUDAGraphRunner(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
**
kwargs
,
**
kwargs
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# KV caches are fixed tensors, so we don't need to copy them.
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
del
kv_caches
# Copy the input tensors to the input buffers.
# Copy the input tensors to the input buffers.
self
.
input_buffers
[
"input_ids"
].
copy_
(
input_ids
,
non_blocking
=
True
)
self
.
input_buffers
[
"input_ids"
].
copy_
(
input_ids
,
non_blocking
=
True
)
...
...
vllm/worker/multi_step_model_runner.py
View file @
cdc1fa12
...
@@ -476,7 +476,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -476,7 +476,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# path for warm up runs
# path for warm up runs
if
not
model_input
.
is_multi_step
:
if
not
model_input
.
is_multi_step
:
return
self
.
_base_model_runner
.
execute_model
(
return
self
.
_base_model_runner
.
execute_model
(
frozen_model_input
,
kv_caches
,
intermediate_tensors
,
num_steps
)
frozen_model_input
,
None
,
intermediate_tensors
,
num_steps
)
# make sure we skip the sampler on the lask rank and only pythonize
# make sure we skip the sampler on the lask rank and only pythonize
# if CPU is ahead.
# if CPU is ahead.
...
@@ -538,7 +538,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -538,7 +538,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# Execute the model
# Execute the model
output
=
self
.
_base_model_runner
.
execute_model
(
frozen_model_input
,
output
=
self
.
_base_model_runner
.
execute_model
(
frozen_model_input
,
kv_caches
,
None
,
intermediate_tensors
,
intermediate_tensors
,
num_steps
=
1
)
num_steps
=
1
)
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment