Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
533d2a1f
Unverified
Commit
533d2a1f
authored
Apr 18, 2024
by
SangBin Cho
Committed by
GitHub
Apr 17, 2024
Browse files
[Typing] Mypy typing part 2 (#4043)
Co-authored-by:
SangBin Cho
<
sangcho@sangcho-LT93GQWG9C.local
>
parent
a5322254
Changes
20
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
180 additions
and
126 deletions
+180
-126
.github/workflows/mypy.yaml
.github/workflows/mypy.yaml
+4
-4
format.sh
format.sh
+4
-4
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+25
-19
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+2
-2
vllm/model_executor/guided_decoding/outlines_decoding.py
vllm/model_executor/guided_decoding/outlines_decoding.py
+2
-2
vllm/model_executor/guided_decoding/outlines_logits_processors.py
...el_executor/guided_decoding/outlines_logits_processors.py
+5
-1
vllm/model_executor/model_loader/neuron.py
vllm/model_executor/model_loader/neuron.py
+9
-7
vllm/model_executor/model_loader/tensorizer.py
vllm/model_executor/model_loader/tensorizer.py
+1
-0
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+4
-0
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+3
-3
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+2
-2
vllm/spec_decode/metrics.py
vllm/spec_decode/metrics.py
+1
-0
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+11
-10
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+5
-1
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+15
-10
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+6
-5
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+56
-40
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+14
-8
vllm/worker/worker.py
vllm/worker/worker.py
+7
-4
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+4
-4
No files found.
.github/workflows/mypy.yaml
View file @
533d2a1f
...
@@ -41,10 +41,10 @@ jobs:
...
@@ -41,10 +41,10 @@ jobs:
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
# TODO(sang): Follow up
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
format.sh
View file @
533d2a1f
...
@@ -104,10 +104,10 @@ mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
...
@@ -104,10 +104,10 @@ mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/
*
.py
--follow-imports
=
skip
--config-file
pyproject.toml
mypy vllm/transformers_utils/
*
.py
--follow-imports
=
skip
--config-file
pyproject.toml
# TODO(sang): Follow up
# TODO(sang): Follow up
#
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/engine/
*
.py
--follow-imports
=
skip
--config-file
pyproject.toml
#
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/worker/
*
.py
--follow-imports
=
skip
--config-file
pyproject.toml
#
mypy vllm/spec_decod
ing
/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/spec_decod
e
/
*
.py
--follow-imports
=
skip
--config-file
pyproject.toml
#
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/model_executor/
*
.py
--follow-imports
=
skip
--config-file
pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
...
...
vllm/engine/async_llm_engine.py
View file @
533d2a1f
...
@@ -2,8 +2,8 @@ import asyncio
...
@@ -2,8 +2,8 @@ import asyncio
import
os
import
os
import
time
import
time
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
AsyncIterator
,
Callable
,
Dict
,
Iterable
,
List
,
Optional
,
from
typing
import
(
Any
,
AsyncIterator
,
Callable
,
Dict
,
Iterable
,
List
,
Set
,
Tuple
,
Type
,
Union
)
Optional
,
Set
,
Tuple
,
Type
,
Union
)
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
...
@@ -52,7 +52,7 @@ class AsyncStream:
...
@@ -52,7 +52,7 @@ class AsyncStream:
def
__init__
(
self
,
request_id
:
str
)
->
None
:
def
__init__
(
self
,
request_id
:
str
)
->
None
:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
_queue
=
asyncio
.
Queue
()
self
.
_queue
:
asyncio
.
Queue
=
asyncio
.
Queue
()
self
.
_finished
=
False
self
.
_finished
=
False
def
put
(
self
,
item
:
Union
[
RequestOutput
,
Exception
])
->
None
:
def
put
(
self
,
item
:
Union
[
RequestOutput
,
Exception
])
->
None
:
...
@@ -312,15 +312,17 @@ class AsyncLLMEngine:
...
@@ -312,15 +312,17 @@ class AsyncLLMEngine:
self
.
max_log_len
=
max_log_len
self
.
max_log_len
=
max_log_len
self
.
engine
=
self
.
_init_engine
(
*
args
,
**
kwargs
)
self
.
engine
=
self
.
_init_engine
(
*
args
,
**
kwargs
)
self
.
background_loop
=
None
self
.
background_loop
:
Optional
[
asyncio
.
Future
]
=
None
# We need to keep a reference to unshielded
# We need to keep a reference to unshielded
# task as well to prevent it from being garbage
# task as well to prevent it from being garbage
# collected
# collected
self
.
_background_loop_unshielded
=
None
self
.
_background_loop_unshielded
:
Optional
[
asyncio
.
Task
[
Any
]]
=
None
self
.
start_engine_loop
=
start_engine_loop
self
.
start_engine_loop
=
start_engine_loop
self
.
_request_tracker
:
Optional
[
RequestTracker
]
=
None
self
.
_errored_with
:
Optional
[
BaseException
]
=
None
self
.
_errored_with
:
Optional
[
BaseException
]
=
None
# Lazy initialized fields
self
.
_request_tracker
:
RequestTracker
@
classmethod
@
classmethod
def
from_engine_args
(
def
from_engine_args
(
cls
,
cls
,
...
@@ -361,11 +363,13 @@ class AsyncLLMEngine:
...
@@ -361,11 +363,13 @@ class AsyncLLMEngine:
@
property
@
property
def
is_running
(
self
)
->
bool
:
def
is_running
(
self
)
->
bool
:
return
(
self
.
background_loop
is
not
None
return
(
self
.
background_loop
is
not
None
and
self
.
_background_loop_unshielded
is
not
None
and
not
self
.
_background_loop_unshielded
.
done
())
and
not
self
.
_background_loop_unshielded
.
done
())
@
property
@
property
def
is_stopped
(
self
)
->
bool
:
def
is_stopped
(
self
)
->
bool
:
return
self
.
errored
or
(
self
.
background_loop
is
not
None
return
self
.
errored
or
(
self
.
background_loop
is
not
None
and
self
.
_background_loop_unshielded
is
not
None
and
self
.
_background_loop_unshielded
.
done
())
and
self
.
_background_loop_unshielded
.
done
())
@
property
@
property
...
@@ -381,7 +385,7 @@ class AsyncLLMEngine:
...
@@ -381,7 +385,7 @@ class AsyncLLMEngine:
async
def
get_tokenizer
(
self
)
->
"PreTrainedTokenizer"
:
async
def
get_tokenizer
(
self
)
->
"PreTrainedTokenizer"
:
if
self
.
engine_use_ray
:
if
self
.
engine_use_ray
:
return
await
self
.
engine
.
get_tokenizer
.
remote
()
return
await
self
.
engine
.
get_tokenizer
.
remote
()
# type: ignore
else
:
else
:
return
self
.
engine
.
get_tokenizer
()
return
self
.
engine
.
get_tokenizer
()
...
@@ -434,7 +438,8 @@ class AsyncLLMEngine:
...
@@ -434,7 +438,8 @@ class AsyncLLMEngine:
# TODO: Maybe add add_request_batch to reduce Ray overhead
# TODO: Maybe add add_request_batch to reduce Ray overhead
try
:
try
:
if
self
.
engine_use_ray
:
if
self
.
engine_use_ray
:
await
self
.
engine
.
add_request
.
remote
(
**
new_request
)
await
self
.
engine
.
add_request
.
remote
(
# type: ignore
**
new_request
)
else
:
else
:
await
self
.
engine
.
add_request_async
(
**
new_request
)
await
self
.
engine
.
add_request_async
(
**
new_request
)
except
ValueError
as
e
:
except
ValueError
as
e
:
...
@@ -449,7 +454,7 @@ class AsyncLLMEngine:
...
@@ -449,7 +454,7 @@ class AsyncLLMEngine:
await
self
.
_engine_abort
(
finished_requests
)
await
self
.
_engine_abort
(
finished_requests
)
if
self
.
engine_use_ray
:
if
self
.
engine_use_ray
:
request_outputs
=
await
self
.
engine
.
step
.
remote
()
request_outputs
=
await
self
.
engine
.
step
.
remote
()
# type: ignore
else
:
else
:
request_outputs
=
await
self
.
engine
.
step_async
()
request_outputs
=
await
self
.
engine
.
step_async
()
...
@@ -462,7 +467,7 @@ class AsyncLLMEngine:
...
@@ -462,7 +467,7 @@ class AsyncLLMEngine:
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
if
self
.
engine_use_ray
:
if
self
.
engine_use_ray
:
await
self
.
engine
.
abort_request
.
remote
(
request_ids
)
await
self
.
engine
.
abort_request
.
remote
(
request_ids
)
# type: ignore
else
:
else
:
self
.
engine
.
abort_request
(
request_ids
)
self
.
engine
.
abort_request
(
request_ids
)
...
@@ -525,11 +530,12 @@ class AsyncLLMEngine:
...
@@ -525,11 +530,12 @@ class AsyncLLMEngine:
arrival_time
=
time
.
time
()
arrival_time
=
time
.
time
()
if
self
.
engine_use_ray
:
if
self
.
engine_use_ray
:
prompt_token_ids
=
await
self
.
engine
.
encode_request_async
.
remote
(
prompt_token_ids
=
await
(
self
.
engine
.
encode_request_async
.
remote
(
# type: ignore
request_id
=
request_id
,
request_id
=
request_id
,
prompt
=
prompt
,
prompt
=
prompt
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
lora_request
=
lora_request
)
lora_request
=
lora_request
)
)
else
:
else
:
prompt_token_ids
=
await
self
.
engine
.
encode_request_async
(
prompt_token_ids
=
await
self
.
engine
.
encode_request_async
(
request_id
=
request_id
,
request_id
=
request_id
,
...
@@ -676,13 +682,13 @@ class AsyncLLMEngine:
...
@@ -676,13 +682,13 @@ class AsyncLLMEngine:
async
def
get_model_config
(
self
)
->
ModelConfig
:
async
def
get_model_config
(
self
)
->
ModelConfig
:
"""Get the model configuration of the vLLM engine."""
"""Get the model configuration of the vLLM engine."""
if
self
.
engine_use_ray
:
if
self
.
engine_use_ray
:
return
await
self
.
engine
.
get_model_config
.
remote
()
return
await
self
.
engine
.
get_model_config
.
remote
()
# type: ignore
else
:
else
:
return
self
.
engine
.
get_model_config
()
return
self
.
engine
.
get_model_config
()
async
def
do_log_stats
(
self
)
->
None
:
async
def
do_log_stats
(
self
)
->
None
:
if
self
.
engine_use_ray
:
if
self
.
engine_use_ray
:
await
self
.
engine
.
do_log_stats
.
remote
()
await
self
.
engine
.
do_log_stats
.
remote
()
# type: ignore
else
:
else
:
self
.
engine
.
do_log_stats
()
self
.
engine
.
do_log_stats
()
...
@@ -695,7 +701,7 @@ class AsyncLLMEngine:
...
@@ -695,7 +701,7 @@ class AsyncLLMEngine:
if
self
.
engine_use_ray
:
if
self
.
engine_use_ray
:
try
:
try
:
await
self
.
engine
.
check_health
.
remote
()
await
self
.
engine
.
check_health
.
remote
()
# type: ignore
except
ray
.
exceptions
.
RayActorError
as
e
:
except
ray
.
exceptions
.
RayActorError
as
e
:
raise
RuntimeError
(
"Engine is dead."
)
from
e
raise
RuntimeError
(
"Engine is dead."
)
from
e
else
:
else
:
...
...
vllm/lora/worker_manager.py
View file @
533d2a1f
...
@@ -107,12 +107,12 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
...
@@ -107,12 +107,12 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
self
.
_lora_manager
:
LoRAModelManager
=
lora_manager
self
.
_lora_manager
:
LoRAModelManager
=
lora_manager
return
lora_manager
.
model
return
lora_manager
.
model
def
set_active_loras
(
self
,
lora_requests
:
Lis
t
[
LoRARequest
],
def
set_active_loras
(
self
,
lora_requests
:
Se
t
[
LoRARequest
],
lora_mapping
:
LoRAMapping
)
->
None
:
lora_mapping
:
LoRAMapping
)
->
None
:
self
.
_apply_loras
(
lora_requests
)
self
.
_apply_loras
(
lora_requests
)
self
.
_lora_manager
.
set_lora_mapping
(
lora_mapping
)
self
.
_lora_manager
.
set_lora_mapping
(
lora_mapping
)
def
_apply_loras
(
self
,
lora_requests
:
Lis
t
[
LoRARequest
])
->
None
:
def
_apply_loras
(
self
,
lora_requests
:
Se
t
[
LoRARequest
])
->
None
:
loras_that_exist
=
self
.
list_loras
()
loras_that_exist
=
self
.
list_loras
()
loras_map
=
{
loras_map
=
{
lora_request
.
lora_int_id
:
lora_request
lora_request
.
lora_int_id
:
lora_request
...
...
vllm/model_executor/guided_decoding/outlines_decoding.py
View file @
533d2a1f
...
@@ -55,7 +55,7 @@ global_thread_pool = None # used for generating logits processor fsm
...
@@ -55,7 +55,7 @@ global_thread_pool = None # used for generating logits processor fsm
async
def
get_outlines_guided_decoding_logits_processor
(
async
def
get_outlines_guided_decoding_logits_processor
(
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
],
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
],
tokenizer
)
->
Union
[
JSONLogitsProcessor
,
RegexLogitsProcessor
]:
tokenizer
)
->
Union
[
JSONLogitsProcessor
,
RegexLogitsProcessor
,
None
]:
"""
"""
Given an OpenAI-compatible request, check for guided decoding parameters
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
and get the necessary logits processor for the given guide.
...
@@ -84,7 +84,7 @@ async def get_outlines_guided_decoding_logits_processor(
...
@@ -84,7 +84,7 @@ async def get_outlines_guided_decoding_logits_processor(
def
_get_guide_and_mode
(
def
_get_guide_and_mode
(
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
]
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
]
)
->
Tuple
[
str
,
GuidedDecodingMode
]:
)
->
Union
[
Tuple
[
str
,
GuidedDecodingMode
]
,
Tuple
[
None
,
None
]]
:
if
request
.
guided_json
:
if
request
.
guided_json
:
json
=
request
.
guided_json
json
=
request
.
guided_json
...
...
vllm/model_executor/guided_decoding/outlines_logits_processors.py
View file @
533d2a1f
...
@@ -21,7 +21,7 @@ from functools import lru_cache
...
@@ -21,7 +21,7 @@ from functools import lru_cache
from
typing
import
Callable
,
DefaultDict
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Callable
,
DefaultDict
,
Dict
,
List
,
Optional
,
Union
import
torch
import
torch
from
outlines.fsm.fsm
import
CFGFSM
,
RegexFSM
from
outlines.fsm.fsm
import
CFGFSM
,
FSM
,
RegexFSM
from
outlines.fsm.json_schema
import
build_regex_from_schema
from
outlines.fsm.json_schema
import
build_regex_from_schema
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
transformers
import
PreTrainedTokenizerBase
from
transformers
import
PreTrainedTokenizerBase
...
@@ -29,6 +29,10 @@ from transformers import PreTrainedTokenizerBase
...
@@ -29,6 +29,10 @@ from transformers import PreTrainedTokenizerBase
class
BaseLogitsProcessor
:
class
BaseLogitsProcessor
:
def
__init__
(
self
):
# Child class should use initialize in their init.
self
.
fsm
:
FSM
def
init_state
(
self
):
def
init_state
(
self
):
"""Initialize the FSM states."""
"""Initialize the FSM states."""
self
.
fsm_state
:
DefaultDict
[
int
,
int
]
=
defaultdict
(
int
)
self
.
fsm_state
:
DefaultDict
[
int
,
int
]
=
defaultdict
(
int
)
...
...
vllm/model_executor/model_loader/neuron.py
View file @
533d2a1f
"""Utilities for selecting and loading neuron models."""
"""Utilities for selecting and loading neuron models."""
import
importlib
import
importlib
import
os
import
os
from
typing
import
Optional
,
T
yp
e
from
typing
import
Dict
,
Optional
,
T
upl
e
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -27,7 +27,7 @@ TORCH_DTYPE_TO_NEURON_AMP = {
...
@@ -27,7 +27,7 @@ TORCH_DTYPE_TO_NEURON_AMP = {
}
}
# Models supported by Neuron.
# Models supported by Neuron.
_NEURON_SUPPORTED_MODELS
=
{
_NEURON_SUPPORTED_MODELS
:
Dict
[
str
,
Tuple
[
str
,
str
,
str
]]
=
{
"LlamaForCausalLM"
:
(
"transformers_neuronx.llama.model"
,
"LlamaForCausalLM"
:
(
"transformers_neuronx.llama.model"
,
"LlamaForSampling"
,
"LlamaForCausalLM"
),
"LlamaForSampling"
,
"LlamaForCausalLM"
),
"MistralForCausalLM"
:
(
"transformers_neuronx.mistral.model"
,
"MistralForCausalLM"
:
(
"transformers_neuronx.mistral.model"
,
...
@@ -43,11 +43,13 @@ class NeuronCasualLM(nn.Module):
...
@@ -43,11 +43,13 @@ class NeuronCasualLM(nn.Module):
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
model
=
None
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
logits_as_input
=
True
)
logits_as_input
=
True
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
# Lazy initialized
self
.
model
:
nn
.
Module
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -74,17 +76,17 @@ class NeuronCasualLM(nn.Module):
...
@@ -74,17 +76,17 @@ class NeuronCasualLM(nn.Module):
def
load_weights
(
self
,
model_name_or_path
:
str
,
**
kwargs
):
def
load_weights
(
self
,
model_name_or_path
:
str
,
**
kwargs
):
arch
=
_get_model_architecture
(
self
.
config
)
arch
=
_get_model_architecture
(
self
.
config
)
neuronx_module_path
,
neuronx_model_cls
,
hf_model_cls
=
(
neuronx_module_path
,
neuronx_model_cls
_name
,
hf_model_cls
_name
=
(
_NEURON_SUPPORTED_MODELS
[
arch
])
_NEURON_SUPPORTED_MODELS
[
arch
])
neuronx_module
=
importlib
.
import_module
(
neuronx_module_path
)
neuronx_module
=
importlib
.
import_module
(
neuronx_module_path
)
neuronx_model_cls
=
getattr
(
neuronx_module
,
neuronx_model_cls
)
neuronx_model_cls
=
getattr
(
neuronx_module
,
neuronx_model_cls
_name
)
split_model_dir
=
f
"
{
model_name_or_path
}
-split"
split_model_dir
=
f
"
{
model_name_or_path
}
-split"
if
os
.
path
.
isdir
(
os
.
path
.
join
(
model_name_or_path
,
if
os
.
path
.
isdir
(
os
.
path
.
join
(
model_name_or_path
,
"pytorch_model.bin"
)):
"pytorch_model.bin"
)):
split_model_dir
=
model_name_or_path
split_model_dir
=
model_name_or_path
elif
not
os
.
path
.
exists
(
f
"
{
model_name_or_path
}
-split"
):
elif
not
os
.
path
.
exists
(
f
"
{
model_name_or_path
}
-split"
):
hf_model_cls
=
getattr
(
transformers
,
hf_model_cls
)
hf_model_cls
=
getattr
(
transformers
,
hf_model_cls
_name
)
from
transformers_neuronx.module
import
save_pretrained_split
from
transformers_neuronx.module
import
save_pretrained_split
hf_model
=
hf_model_cls
.
from_pretrained
(
model_name_or_path
,
hf_model
=
hf_model_cls
.
from_pretrained
(
model_name_or_path
,
...
@@ -96,7 +98,7 @@ class NeuronCasualLM(nn.Module):
...
@@ -96,7 +98,7 @@ class NeuronCasualLM(nn.Module):
self
.
model
.
to_neuron
()
self
.
model
.
to_neuron
()
def
_get_model_architecture
(
config
:
PretrainedConfig
)
->
Type
[
nn
.
Module
]
:
def
_get_model_architecture
(
config
:
PretrainedConfig
)
->
str
:
architectures
=
getattr
(
config
,
"architectures"
,
[])
architectures
=
getattr
(
config
,
"architectures"
,
[])
for
arch
in
architectures
:
for
arch
in
architectures
:
if
arch
in
_NEURON_SUPPORTED_MODELS
:
if
arch
in
_NEURON_SUPPORTED_MODELS
:
...
...
vllm/model_executor/model_loader/tensorizer.py
View file @
533d2a1f
...
@@ -167,6 +167,7 @@ class TensorizerArgs:
...
@@ -167,6 +167,7 @@ class TensorizerArgs:
decryption_params
=
DecryptionParams
.
from_key
(
key
)
decryption_params
=
DecryptionParams
.
from_key
(
key
)
self
.
deserializer_params
[
'encryption'
]
=
decryption_params
self
.
deserializer_params
[
'encryption'
]
=
decryption_params
@
staticmethod
def
add_cli_args
(
def
add_cli_args
(
parser
:
argparse
.
ArgumentParser
)
->
argparse
.
ArgumentParser
:
parser
:
argparse
.
ArgumentParser
)
->
argparse
.
ArgumentParser
:
"""Tensorizer CLI arguments"""
"""Tensorizer CLI arguments"""
...
...
vllm/model_executor/sampling_metadata.py
View file @
533d2a1f
...
@@ -113,6 +113,8 @@ class SamplingTensors:
...
@@ -113,6 +113,8 @@ class SamplingTensors:
get_num_triton_sampler_splits
(
vocab_size
))
get_num_triton_sampler_splits
(
vocab_size
))
sample_indices_start_idx
=
0
sample_indices_start_idx
=
0
assert
sampling_metadata
.
seq_groups
is
not
None
assert
sampling_metadata
.
seq_data
is
not
None
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
seq_ids
,
sampling_params
=
seq_group
seq_ids
,
sampling_params
=
seq_group
temperature
=
sampling_params
.
temperature
temperature
=
sampling_params
.
temperature
...
@@ -147,6 +149,7 @@ class SamplingTensors:
...
@@ -147,6 +149,7 @@ class SamplingTensors:
and
sampling_params
.
prompt_logprobs
is
not
None
):
and
sampling_params
.
prompt_logprobs
is
not
None
):
# For tokens in the prompt that we only need to get
# For tokens in the prompt that we only need to get
# their logprobs
# their logprobs
assert
sampling_metadata
.
prompt_lens
is
not
None
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
temperatures
+=
[
temperature
]
*
(
prompt_len
-
1
)
temperatures
+=
[
temperature
]
*
(
prompt_len
-
1
)
top_ps
+=
[
top_p
]
*
(
prompt_len
-
1
)
top_ps
+=
[
top_p
]
*
(
prompt_len
-
1
)
...
@@ -172,6 +175,7 @@ class SamplingTensors:
...
@@ -172,6 +175,7 @@ class SamplingTensors:
is_prompt
=
i
<
sampling_metadata
.
num_prompts
is_prompt
=
i
<
sampling_metadata
.
num_prompts
if
is_prompt
:
if
is_prompt
:
prompt_best_of
.
append
(
sampling_params
.
best_of
)
prompt_best_of
.
append
(
sampling_params
.
best_of
)
assert
sampling_metadata
.
prompt_lens
is
not
None
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
if
sampling_params
.
prompt_logprobs
is
not
None
:
if
sampling_params
.
prompt_logprobs
is
not
None
:
...
...
vllm/spec_decode/batch_expansion.py
View file @
533d2a1f
...
@@ -106,7 +106,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -106,7 +106,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def
_expand_batch
(
def
_expand_batch
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_token_ids_list
:
List
[
TokenId
],
proposal_token_ids_list
:
List
[
List
[
TokenId
]
]
,
proposal_lens_list
:
List
[
int
],
proposal_lens_list
:
List
[
int
],
)
->
Tuple
[
List
[
int
],
List
[
int
],
List
[
SequenceGroupMetadata
],
int
]:
)
->
Tuple
[
List
[
int
],
List
[
int
],
List
[
SequenceGroupMetadata
],
int
]:
"""Given the input sequences and potentially multiple corresponding
"""Given the input sequences and potentially multiple corresponding
...
@@ -218,7 +218,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -218,7 +218,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def
_create_target_seq_group_metadata
(
def
_create_target_seq_group_metadata
(
self
,
self
,
input_seq_group_metadata
:
SequenceGroupMetadata
,
input_seq_group_metadata
:
SequenceGroupMetadata
,
proposal_token_ids
:
List
[
TokenId
],
# shape: [batch_size, k]
proposal_token_ids
:
List
[
List
[
TokenId
]
]
,
# shape: [batch_size, k]
batch_index
:
int
,
batch_index
:
int
,
target_seq_ids_iter
:
Iterator
[
TargetSeqId
],
target_seq_ids_iter
:
Iterator
[
TargetSeqId
],
)
->
List
[
SequenceGroupMetadata
]:
)
->
List
[
SequenceGroupMetadata
]:
...
@@ -360,7 +360,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -360,7 +360,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
[0, 1, 2]
[0, 1, 2]
[0, 1, 2, 3]
[0, 1, 2, 3]
"""
"""
empty_token_ids
=
[]
empty_token_ids
:
List
[
TokenId
]
=
[]
token_ids_to_score
=
[
empty_token_ids
]
token_ids_to_score
=
[
empty_token_ids
]
token_ids_to_score
.
extend
([
token_ids_to_score
.
extend
([
...
...
vllm/spec_decode/interfaces.py
View file @
533d2a1f
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
import
torch
import
torch
...
@@ -73,5 +73,5 @@ class SpeculativeScorer(ABC):
...
@@ -73,5 +73,5 @@ class SpeculativeScorer(ABC):
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
k
:
int
,
k
:
int
,
proposals
:
SpeculativeProposals
,
proposals
:
SpeculativeProposals
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
SpeculativeScores
:
raise
NotImplementedError
raise
NotImplementedError
vllm/spec_decode/metrics.py
View file @
533d2a1f
...
@@ -112,6 +112,7 @@ class AsyncMetricsCollector:
...
@@ -112,6 +112,7 @@ class AsyncMetricsCollector:
Returns a CUDA event recording when the copy is complete.
Returns a CUDA event recording when the copy is complete.
"""
"""
assert
self
.
_copy_stream
is
not
None
self
.
_copy_stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
self
.
_copy_stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
with
torch
.
cuda
.
stream
(
self
.
_copy_stream
):
with
torch
.
cuda
.
stream
(
self
.
_copy_stream
):
...
...
vllm/spec_decode/multi_step_worker.py
View file @
533d2a1f
...
@@ -26,7 +26,8 @@ class MultiStepWorker(Worker):
...
@@ -26,7 +26,8 @@ class MultiStepWorker(Worker):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
_proposer
:
Optional
[
DraftModelTop1Proposer
]
=
None
# Lazy initialization list.
self
.
_proposer
:
DraftModelTop1Proposer
def
init_device
(
self
):
def
init_device
(
self
):
super
().
init_device
()
super
().
init_device
()
...
@@ -338,10 +339,10 @@ class DraftModelTop1Proposer(SpeculativeProposer):
...
@@ -338,10 +339,10 @@ class DraftModelTop1Proposer(SpeculativeProposer):
self
.
_vocab_size
,
self
.
_vocab_size
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
self
.
_device
)
device
=
self
.
_device
)
proposal_lens
=
torch
.
zeros
(
len
(
proposal_lens
),
proposal_lens
_tensor
=
torch
.
zeros
(
len
(
proposal_lens
),
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
_device
)
device
=
self
.
_device
)
return
proposal_tokens
,
proposal_probs
,
proposal_lens
return
proposal_tokens
,
proposal_probs
,
proposal_lens
_tensor
sampler_output
=
maybe_sampler_output
sampler_output
=
maybe_sampler_output
...
@@ -376,9 +377,9 @@ class DraftModelTop1Proposer(SpeculativeProposer):
...
@@ -376,9 +377,9 @@ class DraftModelTop1Proposer(SpeculativeProposer):
proposal_tokens
,
proposal_probs
=
(
entire_proposal_tokens
,
proposal_tokens
,
proposal_probs
=
(
entire_proposal_tokens
,
entire_proposal_probs
)
entire_proposal_probs
)
proposal_lens
=
torch
.
zeros
(
batch_size
,
proposal_lens
_tensor
=
torch
.
zeros
(
batch_size
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
_device
)
device
=
self
.
_device
)
proposal_lens
[
nonzero_proposal_len_indices
]
=
max_proposal_len
proposal_lens
_tensor
[
nonzero_proposal_len_indices
]
=
max_proposal_len
return
proposal_tokens
,
proposal_probs
,
proposal_lens
return
proposal_tokens
,
proposal_probs
,
proposal_lens
_tensor
vllm/spec_decode/spec_decode_worker.py
View file @
533d2a1f
...
@@ -89,7 +89,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -89,7 +89,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
probs_dtype
=
self
.
rejection_sampler
.
probs_dtype
self
.
probs_dtype
=
self
.
rejection_sampler
.
probs_dtype
self
.
token_id_dtype
=
self
.
rejection_sampler
.
token_id_dtype
self
.
token_id_dtype
=
self
.
rejection_sampler
.
token_id_dtype
self
.
scorer
:
SpeculativeScorer
=
None
# Lazy initiazliation.
self
.
scorer
:
SpeculativeScorer
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
"""Initialize both scorer and proposer models.
"""Initialize both scorer and proposer models.
...
@@ -233,6 +234,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -233,6 +234,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
logger
.
info
(
"get spec proposals"
)
logger
.
info
(
"get spec proposals"
)
# Generate proposals using draft worker.
# Generate proposals using draft worker.
assert
blocks_to_swap_in
is
not
None
assert
blocks_to_swap_out
is
not
None
assert
blocks_to_copy
is
not
None
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
k
)
blocks_to_copy
,
k
)
...
...
vllm/worker/cpu_model_runner.py
View file @
533d2a1f
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
...
@@ -48,14 +49,15 @@ class CPUModelRunner:
...
@@ -48,14 +49,15 @@ class CPUModelRunner:
if
device_config
is
not
None
else
DeviceConfig
())
if
device_config
is
not
None
else
DeviceConfig
())
self
.
device
=
self
.
device_config
.
device
self
.
device
=
self
.
device_config
.
device
self
.
model
=
None
self
.
block_size
=
None
# Set after initial profiling.
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
attn_backend
=
get_attn_backend
(
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
dtype
if
model_config
is
not
None
else
None
)
self
.
model_config
.
dtype
if
model_config
is
not
None
else
None
)
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
self
.
block_size
:
int
# Set after initial profiling.
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
load_config
=
self
.
load_config
,
load_config
=
self
.
load_config
,
...
@@ -245,7 +247,11 @@ class CPUModelRunner:
...
@@ -245,7 +247,11 @@ class CPUModelRunner:
selected_token_indices
:
List
[
int
]
=
[]
selected_token_indices
:
List
[
int
]
=
[]
generators
:
List
[
torch
.
Generator
]
=
[]
generators
:
List
[
torch
.
Generator
]
=
[]
selected_token_start_idx
=
0
selected_token_start_idx
=
0
categorized_sample_indices
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices
:
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices_start_idx
=
0
categorized_sample_indices_start_idx
=
0
categorized_sampled_token_indices_start_idx
=
0
categorized_sampled_token_indices_start_idx
=
0
...
@@ -262,10 +268,9 @@ class CPUModelRunner:
...
@@ -262,10 +268,9 @@ class CPUModelRunner:
categorized_sample_indices_start_idx
+=
subquery_len
-
1
categorized_sample_indices_start_idx
+=
subquery_len
-
1
categorized_sample_indices
[
categorized_sample_indices
[
sampling_params
.
sampling_type
].
append
([
sampling_params
.
sampling_type
].
append
(
categorized_sample_indices_start_idx
,
(
categorized_sample_indices_start_idx
,
categorized_sampled_token_indices_start_idx
categorized_sampled_token_indices_start_idx
))
])
categorized_sample_indices_start_idx
+=
1
categorized_sample_indices_start_idx
+=
1
categorized_sampled_token_indices_start_idx
+=
1
categorized_sampled_token_indices_start_idx
+=
1
...
@@ -328,7 +333,7 @@ class CPUModelRunner:
...
@@ -328,7 +333,7 @@ class CPUModelRunner:
def
prepare_input_tensors
(
def
prepare_input_tensors
(
self
,
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]
]
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
SamplingMetadata
]:
SamplingMetadata
]:
if
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
...
@@ -381,7 +386,7 @@ class CPUModelRunner:
...
@@ -381,7 +386,7 @@ class CPUModelRunner:
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]
]
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
...
...
vllm/worker/cpu_worker.py
View file @
533d2a1f
"""A CPU worker class."""
"""A CPU worker class."""
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -152,8 +152,8 @@ class CPUWorker(LoraNotSupportedWorkerBase):
...
@@ -152,8 +152,8 @@ class CPUWorker(LoraNotSupportedWorkerBase):
is_driver_worker
=
is_driver_worker
)
is_driver_worker
=
is_driver_worker
)
# Uninitialized cache engine. Will be initialized by
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
# initialize_cache.
self
.
cache_engine
=
No
ne
self
.
cache_engine
:
CPUCacheEngi
ne
self
.
cpu_cache
=
None
self
.
cpu_cache
:
List
[
torch
.
Tensor
]
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
self
.
init_distributed_environment
()
self
.
init_distributed_environment
()
...
@@ -257,13 +257,13 @@ class CPUWorker(LoraNotSupportedWorkerBase):
...
@@ -257,13 +257,13 @@ class CPUWorker(LoraNotSupportedWorkerBase):
)
->
List
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
if
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
assert
seq_group_metadata_list
is
not
None
assert
seq_group_metadata_list
is
not
None
num_seq_groups
=
len
(
seq_group_metadata_list
)
num_seq_groups
:
int
=
len
(
seq_group_metadata_list
)
assert
blocks_to_swap_in
is
not
None
assert
blocks_to_swap_in
is
not
None
assert
blocks_to_swap_out
is
not
None
assert
blocks_to_swap_out
is
not
None
assert
blocks_to_copy
is
not
None
assert
blocks_to_copy
is
not
None
assert
len
(
blocks_to_swap_in
)
==
0
assert
len
(
blocks_to_swap_in
)
==
0
assert
len
(
blocks_to_swap_out
)
==
0
assert
len
(
blocks_to_swap_out
)
==
0
data
=
{
data
:
Dict
[
str
,
Any
]
=
{
"num_seq_groups"
:
num_seq_groups
,
"num_seq_groups"
:
num_seq_groups
,
"blocks_to_copy"
:
blocks_to_copy
,
"blocks_to_copy"
:
blocks_to_copy
,
}
}
...
@@ -273,6 +273,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
...
@@ -273,6 +273,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
num_seq_groups
=
data
[
"num_seq_groups"
]
num_seq_groups
=
data
[
"num_seq_groups"
]
blocks_to_copy
=
data
[
"blocks_to_copy"
]
blocks_to_copy
=
data
[
"blocks_to_copy"
]
assert
blocks_to_copy
is
not
None
self
.
cache_copy
(
blocks_to_copy
)
self
.
cache_copy
(
blocks_to_copy
)
# If there is no input, we don't need to execute the model.
# If there is no input, we don't need to execute the model.
...
...
vllm/worker/model_runner.py
View file @
533d2a1f
...
@@ -128,23 +128,17 @@ class ModelRunner:
...
@@ -128,23 +128,17 @@ class ModelRunner:
if
device_config
is
not
None
else
DeviceConfig
())
if
device_config
is
not
None
else
DeviceConfig
())
self
.
device
=
self
.
device_config
.
device
self
.
device
=
self
.
device_config
.
device
self
.
model
=
None
# Set after load_model.
self
.
block_size
=
None
# Set after initial profiling.
self
.
lora_manager
:
LRUCacheWorkerLoRAManager
=
None
self
.
lora_manager
=
None
self
.
graph_runners
:
Dict
[
int
,
CUDAGraphRunner
]
=
{}
self
.
graph_runners
:
Dict
[
int
,
CUDAGraphRunner
]
=
{}
self
.
graph_memory_pool
=
None
# Set during graph capture.
self
.
graph_memory_pool
:
Optional
[
Tuple
[
int
,
int
]]
=
None
# Set during graph capture.
self
.
max_context_len_to_capture
=
(
self
.
max_context_len_to_capture
=
(
self
.
model_config
.
max_context_len_to_capture
self
.
model_config
.
max_context_len_to_capture
if
self
.
model_config
is
not
None
else
0
)
if
self
.
model_config
is
not
None
else
0
)
# When using CUDA graph, the input block tables must be padded to
# max_context_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self
.
graph_block_tables
=
None
# Set after initial profiling.
self
.
pin_memory
=
is_pin_memory_available
()
self
.
pin_memory
=
is_pin_memory_available
()
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
vision_language_config
=
vision_language_config
self
.
vision_language_config
=
vision_language_config
...
@@ -152,6 +146,17 @@ class ModelRunner:
...
@@ -152,6 +146,17 @@ class ModelRunner:
self
.
attn_backend
=
get_attn_backend
(
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
dtype
if
model_config
is
not
None
else
None
)
self
.
model_config
.
dtype
if
model_config
is
not
None
else
None
)
# Lazy initialization
self
.
model
:
torch
.
nn
.
Module
# Set after load_model
self
.
block_size
:
int
# Set after initial profiling.
# When using CUDA graph, the input block tables must be padded to
# max_context_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self
.
graph_block_tables
:
torch
.
Tensor
# Set after initial profiling.
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
with
CudaMemoryProfiler
()
as
m
:
with
CudaMemoryProfiler
()
as
m
:
self
.
model
=
get_model
(
self
.
model
=
get_model
(
...
@@ -489,16 +494,16 @@ class ModelRunner:
...
@@ -489,16 +494,16 @@ class ModelRunner:
lora_index_mapping
.
append
(
0
)
lora_index_mapping
.
append
(
0
)
batch_size
=
graph_batch_size
batch_size
=
graph_batch_size
context_lens
=
torch
.
tensor
(
context_lens
,
context_lens
_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
device
=
self
.
device
)
if
use_captured_graph
:
if
use_captured_graph
:
# When using cuda-graph all these tensors should be
# When using cuda-graph all these tensors should be
# padded.
# padded.
assert
context_lens
.
shape
[
0
]
==
len
(
input_tokens
)
assert
context_lens
_tensor
.
shape
[
0
]
==
len
(
input_tokens
)
assert
context_lens
.
shape
[
0
]
==
len
(
input_positions
)
assert
context_lens
_tensor
.
shape
[
0
]
==
len
(
input_positions
)
assert
context_lens
.
shape
[
0
]
==
len
(
slot_mapping
)
assert
context_lens
_tensor
.
shape
[
0
]
==
len
(
slot_mapping
)
# The shape of graph_block_tables is
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
# [max batch size, max context len // block size].
...
@@ -527,7 +532,7 @@ class ModelRunner:
...
@@ -527,7 +532,7 @@ class ModelRunner:
max_prompt_len
=
None
,
max_prompt_len
=
None
,
subquery_start_loc
=
None
,
subquery_start_loc
=
None
,
seq_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens
=
context_lens
,
context_lens
=
context_lens
_tensor
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
use_cuda_graph
=
use_captured_graph
,
)
)
...
@@ -551,7 +556,11 @@ class ModelRunner:
...
@@ -551,7 +556,11 @@ class ModelRunner:
selected_token_indices
:
List
[
int
]
=
[]
selected_token_indices
:
List
[
int
]
=
[]
generators
:
List
[
torch
.
Generator
]
=
[]
generators
:
List
[
torch
.
Generator
]
=
[]
selected_token_start_idx
=
0
selected_token_start_idx
=
0
categorized_sample_indices
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices
:
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices_start_idx
=
0
categorized_sample_indices_start_idx
=
0
categorized_sampled_token_indices_start_idx
=
0
categorized_sampled_token_indices_start_idx
=
0
...
@@ -569,10 +578,9 @@ class ModelRunner:
...
@@ -569,10 +578,9 @@ class ModelRunner:
categorized_sample_indices_start_idx
+=
subquery_len
-
1
categorized_sample_indices_start_idx
+=
subquery_len
-
1
categorized_sample_indices
[
categorized_sample_indices
[
sampling_params
.
sampling_type
].
append
([
sampling_params
.
sampling_type
].
append
(
categorized_sample_indices_start_idx
,
(
categorized_sample_indices_start_idx
,
categorized_sampled_token_indices_start_idx
categorized_sampled_token_indices_start_idx
))
])
categorized_sample_indices_start_idx
+=
1
categorized_sample_indices_start_idx
+=
1
categorized_sampled_token_indices_start_idx
+=
1
categorized_sampled_token_indices_start_idx
+=
1
...
@@ -596,6 +604,7 @@ class ModelRunner:
...
@@ -596,6 +604,7 @@ class ModelRunner:
categorized_sample_indices
[
categorized_sample_indices
[
sampling_params
.
sampling_type
].
extend
(
sampling_params
.
sampling_type
].
extend
(
list
(
zip
(
zip
(
range
(
range
(
categorized_sample_indices_start_idx
,
categorized_sample_indices_start_idx
,
...
@@ -603,8 +612,8 @@ class ModelRunner:
...
@@ -603,8 +612,8 @@ class ModelRunner:
num_seqs
),
num_seqs
),
range
(
range
(
categorized_sampled_token_indices_start_idx
,
categorized_sampled_token_indices_start_idx
,
categorized_sampled_token_indices_start_idx
+
categorized_sampled_token_indices_start_idx
num_seqs
)))
+
num_seqs
)))
)
categorized_sample_indices_start_idx
+=
num_seqs
categorized_sample_indices_start_idx
+=
num_seqs
categorized_sampled_token_indices_start_idx
+=
num_seqs
categorized_sampled_token_indices_start_idx
+=
num_seqs
...
@@ -641,9 +650,9 @@ class ModelRunner:
...
@@ -641,9 +650,9 @@ class ModelRunner:
def
prepare_input_tensors
(
def
prepare_input_tensors
(
self
,
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]
]
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
SamplingMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
SamplingMetadata
,
Set
[
in
t
],
LoRAMapping
,
torch
.
Tensor
]:
Set
[
LoRAReques
t
],
LoRAMapping
,
torch
.
Tensor
]:
if
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
prefill_reqs
=
[]
prefill_reqs
=
[]
decode_reqs
=
[]
decode_reqs
=
[]
...
@@ -741,6 +750,7 @@ class ModelRunner:
...
@@ -741,6 +750,7 @@ class ModelRunner:
if
prefill_attn_metadata
is
not
None
:
if
prefill_attn_metadata
is
not
None
:
metadata_dict
.
update
(
prefill_attn_metadata
.
asdict_zerocopy
())
metadata_dict
.
update
(
prefill_attn_metadata
.
asdict_zerocopy
())
else
:
else
:
assert
decode_attn_metadata
is
not
None
metadata_dict
.
update
(
decode_attn_metadata
.
asdict_zerocopy
())
metadata_dict
.
update
(
decode_attn_metadata
.
asdict_zerocopy
())
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
...
@@ -809,7 +819,7 @@ class ModelRunner:
...
@@ -809,7 +819,7 @@ class ModelRunner:
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]
]
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
...
@@ -923,7 +933,7 @@ class ModelRunner:
...
@@ -923,7 +933,7 @@ class ModelRunner:
raise
RuntimeError
(
"LoRA is not enabled."
)
raise
RuntimeError
(
"LoRA is not enabled."
)
return
self
.
lora_manager
.
remove_all_loras
()
return
self
.
lora_manager
.
remove_all_loras
()
def
set_active_loras
(
self
,
lora_requests
:
Lis
t
[
LoRARequest
],
def
set_active_loras
(
self
,
lora_requests
:
Se
t
[
LoRARequest
],
lora_mapping
:
LoRAMapping
)
->
None
:
lora_mapping
:
LoRAMapping
)
->
None
:
if
not
self
.
lora_manager
:
if
not
self
.
lora_manager
:
raise
RuntimeError
(
"LoRA is not enabled."
)
raise
RuntimeError
(
"LoRA is not enabled."
)
...
@@ -1065,10 +1075,16 @@ class CUDAGraphRunner:
...
@@ -1065,10 +1075,16 @@ class CUDAGraphRunner:
def
__init__
(
self
,
model
:
nn
.
Module
):
def
__init__
(
self
,
model
:
nn
.
Module
):
self
.
model
=
model
self
.
model
=
model
self
.
graph
=
None
self
.
input_buffers
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
input_buffers
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
output_buffers
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
output_buffers
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
_graph
:
Optional
[
torch
.
cuda
.
CUDAGraph
]
=
None
@
property
def
graph
(
self
):
assert
self
.
_graph
is
not
None
return
self
.
_graph
def
capture
(
def
capture
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -1078,7 +1094,7 @@ class CUDAGraphRunner:
...
@@ -1078,7 +1094,7 @@ class CUDAGraphRunner:
memory_pool
,
memory_pool
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
assert
self
.
graph
is
None
assert
self
.
_
graph
is
None
# Run the model once without capturing the graph.
# Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
# kernel launches for initial benchmarking (e.g., Triton autotune).
...
@@ -1095,8 +1111,8 @@ class CUDAGraphRunner:
...
@@ -1095,8 +1111,8 @@ class CUDAGraphRunner:
# Capture the graph.
# Capture the graph.
# NOTE(woosuk): Python 3.8 does not support multi-line with statements.
# NOTE(woosuk): Python 3.8 does not support multi-line with statements.
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self
.
graph
=
torch
.
cuda
.
CUDAGraph
()
self
.
_
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
self
.
graph
,
pool
=
memory_pool
):
# noqa: SIM117
with
torch
.
cuda
.
graph
(
self
.
_
graph
,
pool
=
memory_pool
):
# noqa: SIM117
with
_maybe_pynccl
():
with
_maybe_pynccl
():
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
input_ids
,
input_ids
,
...
...
vllm/worker/neuron_model_runner.py
View file @
533d2a1f
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
vllm.config
import
(
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
from
vllm.config
import
(
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
SchedulerConfig
)
...
@@ -34,9 +35,11 @@ class NeuronModelRunner:
...
@@ -34,9 +35,11 @@ class NeuronModelRunner:
self
.
device_config
=
(
device_config
self
.
device_config
=
(
device_config
if
device_config
is
not
None
else
DeviceConfig
())
if
device_config
is
not
None
else
DeviceConfig
())
self
.
device
=
self
.
device_config
.
device
self
.
device
=
self
.
device_config
.
device
self
.
model
=
None
self
.
pin_memory
=
is_pin_memory_available
()
self
.
pin_memory
=
is_pin_memory_available
()
# Lazy initialization.
self
.
model
:
nn
.
Module
# initialize after load_model.
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
self
.
model
=
get_neuron_model
(
self
.
model_config
,
self
.
model
=
get_neuron_model
(
self
.
model_config
,
parallel_config
=
self
.
parallel_config
,
parallel_config
=
self
.
parallel_config
,
...
@@ -147,7 +150,11 @@ class NeuronModelRunner:
...
@@ -147,7 +150,11 @@ class NeuronModelRunner:
selected_token_indices
:
List
[
int
]
=
[]
selected_token_indices
:
List
[
int
]
=
[]
generators
:
List
[
torch
.
Generator
]
=
[]
generators
:
List
[
torch
.
Generator
]
=
[]
selected_token_start_idx
=
0
selected_token_start_idx
=
0
categorized_sample_indices
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices
:
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices_start_idx
=
0
categorized_sample_indices_start_idx
=
0
categorized_sampled_token_indices_start_idx
=
0
categorized_sampled_token_indices_start_idx
=
0
...
@@ -165,10 +172,9 @@ class NeuronModelRunner:
...
@@ -165,10 +172,9 @@ class NeuronModelRunner:
categorized_sample_indices_start_idx
+=
prompt_len
-
1
categorized_sample_indices_start_idx
+=
prompt_len
-
1
categorized_sample_indices
[
categorized_sample_indices
[
sampling_params
.
sampling_type
].
append
([
sampling_params
.
sampling_type
].
append
(
categorized_sample_indices_start_idx
,
(
categorized_sample_indices_start_idx
,
categorized_sampled_token_indices_start_idx
categorized_sampled_token_indices_start_idx
))
])
categorized_sample_indices_start_idx
+=
1
categorized_sample_indices_start_idx
+=
1
categorized_sampled_token_indices_start_idx
+=
1
categorized_sampled_token_indices_start_idx
+=
1
...
@@ -237,7 +243,7 @@ class NeuronModelRunner:
...
@@ -237,7 +243,7 @@ class NeuronModelRunner:
def
prepare_input_tensors
(
def
prepare_input_tensors
(
self
,
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]
]
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
SamplingMetadata
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
SamplingMetadata
]:
# NOTE: We assume that all sequences in the group are all prompts or
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
# all decodes.
...
@@ -259,7 +265,7 @@ class NeuronModelRunner:
...
@@ -259,7 +265,7 @@ class NeuronModelRunner:
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]
]
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
(
input_tokens
,
input_positions
,
input_block_ids
,
sampling_metadata
(
input_tokens
,
input_positions
,
input_block_ids
,
sampling_metadata
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
...
...
vllm/worker/worker.py
View file @
533d2a1f
"""A GPU worker class."""
"""A GPU worker class."""
import
gc
import
gc
import
os
import
os
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -82,8 +82,8 @@ class Worker(WorkerBase):
...
@@ -82,8 +82,8 @@ class Worker(WorkerBase):
)
)
# Uninitialized cache engine. Will be initialized by
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
# initialize_cache.
self
.
cache_engine
=
No
ne
self
.
cache_engine
:
CacheEngi
ne
self
.
gpu_cache
=
None
self
.
gpu_cache
:
List
[
torch
.
Tensor
]
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
if
self
.
device_config
.
device
.
type
==
"cuda"
:
if
self
.
device_config
.
device
.
type
==
"cuda"
:
...
@@ -223,7 +223,7 @@ class Worker(WorkerBase):
...
@@ -223,7 +223,7 @@ class Worker(WorkerBase):
assert
blocks_to_swap_in
is
not
None
assert
blocks_to_swap_in
is
not
None
assert
blocks_to_swap_out
is
not
None
assert
blocks_to_swap_out
is
not
None
assert
blocks_to_copy
is
not
None
assert
blocks_to_copy
is
not
None
data
=
{
data
:
Dict
[
str
,
Any
]
=
{
"num_seq_groups"
:
num_seq_groups
,
"num_seq_groups"
:
num_seq_groups
,
"blocks_to_swap_in"
:
blocks_to_swap_in
,
"blocks_to_swap_in"
:
blocks_to_swap_in
,
"blocks_to_swap_out"
:
blocks_to_swap_out
,
"blocks_to_swap_out"
:
blocks_to_swap_out
,
...
@@ -237,6 +237,9 @@ class Worker(WorkerBase):
...
@@ -237,6 +237,9 @@ class Worker(WorkerBase):
blocks_to_swap_out
=
data
[
"blocks_to_swap_out"
]
blocks_to_swap_out
=
data
[
"blocks_to_swap_out"
]
blocks_to_copy
=
data
[
"blocks_to_copy"
]
blocks_to_copy
=
data
[
"blocks_to_copy"
]
assert
blocks_to_swap_in
is
not
None
assert
blocks_to_swap_out
is
not
None
assert
blocks_to_copy
is
not
None
self
.
cache_swap
(
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
)
self
.
cache_swap
(
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
)
# If there is no input, we don't need to execute the model.
# If there is no input, we don't need to execute the model.
...
...
vllm/worker/worker_base.py
View file @
533d2a1f
import
importlib
import
importlib
import
os
import
os
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
,
Set
,
Tuple
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -56,7 +56,7 @@ class WorkerBase(ABC):
...
@@ -56,7 +56,7 @@ class WorkerBase(ABC):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
get_cache_block_size_bytes
()
->
int
:
def
get_cache_block_size_bytes
(
self
)
->
int
:
"""Return the size of a single cache block, in bytes. Used in
"""Return the size of a single cache block, in bytes. Used in
speculative decoding.
speculative decoding.
"""
"""
...
@@ -71,7 +71,7 @@ class WorkerBase(ABC):
...
@@ -71,7 +71,7 @@ class WorkerBase(ABC):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
list_loras
(
self
)
->
Lis
t
[
int
]:
def
list_loras
(
self
)
->
Se
t
[
int
]:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -86,7 +86,7 @@ class LoraNotSupportedWorkerBase(WorkerBase):
...
@@ -86,7 +86,7 @@ class LoraNotSupportedWorkerBase(WorkerBase):
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
raise
ValueError
(
f
"
{
type
(
self
)
}
does not support LoRA"
)
raise
ValueError
(
f
"
{
type
(
self
)
}
does not support LoRA"
)
def
list_loras
(
self
)
->
Lis
t
[
int
]:
def
list_loras
(
self
)
->
Se
t
[
int
]:
raise
ValueError
(
f
"
{
type
(
self
)
}
does not support LoRA"
)
raise
ValueError
(
f
"
{
type
(
self
)
}
does not support LoRA"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment