Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f48954a4
Commit
f48954a4
authored
Jun 12, 2024
by
zhuwenwen
Browse files
merge v0.5.0
parents
1dba29d3
8f89d720
Changes
253
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
322 additions
and
214 deletions
+322
-214
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+1
-1
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+9
-6
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+5
-28
vllm/spec_decode/proposer_worker_base.py
vllm/spec_decode/proposer_worker_base.py
+44
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+10
-9
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+21
-30
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+3
-8
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+17
-9
vllm/transformers_utils/image_processor.py
vllm/transformers_utils/image_processor.py
+45
-0
vllm/utils.py
vllm/utils.py
+16
-12
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+37
-20
vllm/worker/embedding_model_runner.py
vllm/worker/embedding_model_runner.py
+5
-5
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+109
-86
No files found.
vllm/spec_decode/interfaces.py
View file @
f48954a4
...
...
@@ -55,7 +55,7 @@ class SpeculativeScores:
class
SpeculativeProposer
(
ABC
):
@
abstractmethod
def
get_proposals
(
def
get_
spec_
proposals
(
self
,
execute_model_req
:
ExecuteModelRequest
,
)
->
SpeculativeProposals
:
...
...
vllm/spec_decode/multi_step_worker.py
View file @
f48954a4
...
...
@@ -7,11 +7,12 @@ import torch
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker
import
Worker
class
MultiStepWorker
(
Worker
):
class
MultiStepWorker
(
Worker
,
ProposerWorkerBase
):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
...
...
@@ -33,7 +34,7 @@ class MultiStepWorker(Worker):
super
().
init_device
()
self
.
_proposer
=
Top1Proposer
(
weakref
.
proxy
(
self
),
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
self
.
device
,
self
.
vocab_size
,
max_proposal_len
=
self
.
max_model_len
,
...
...
@@ -92,11 +93,12 @@ class MultiStepWorker(Worker):
speculative tokens per sequence is determined by max_proposal_len.
"""
return
self
.
_proposer
.
get_proposals
(
execute_model_req
)
return
self
.
_proposer
.
get_
spec_
proposals
(
execute_model_req
)
@
staticmethod
def
_append_new_tokens
(
self
,
model_output
:
SamplerOutput
,
seq_group_metadata_list
:
SequenceGroupMetadata
)
->
None
:
model_output
:
List
[
SamplerOutput
]
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
)
->
None
:
"""Given model output from a single run, append the tokens to the
sequences. This is normally done outside of the worker, but it is
required if the worker is to perform multiple forward passes.
...
...
@@ -116,8 +118,9 @@ class MultiStepWorker(Worker):
seq
.
append_token_id
(
token_id
,
token_logprob
.
logprob
)
seq
.
update_num_computed_tokens
(
1
)
@
staticmethod
def
_shallow_copy_inputs
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
)
->
List
[
SequenceGroupMetadata
]:
"""Copy input data structures to remove side-effects when input data
structures are shared with other modules.
...
...
vllm/spec_decode/ngram_worker.py
View file @
f48954a4
...
...
@@ -5,15 +5,16 @@ import torch
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
class
NGramWorker
(
LoraNotSupportedWorkerBase
):
class
NGramWorker
(
NonLLMProposerWorkerBase
,
LoraNotSupportedWorkerBase
):
"""NGramWorker provides a light drafter without need for model.
Current NGramWorker only implement prompt lookup decoding,
and in future we may also do RAG type drafter and other scen
e
rios
and in future we may also do RAG type drafter and other scen
a
rios
which don't rely on LLM model to give proposals.
"""
...
...
@@ -38,34 +39,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
# Current only support Top1Proposer
self
.
_proposer
=
Top1Proposer
(
weakref
.
proxy
(
self
),
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
device
=
self
.
device
,
vocab_size
=
self
.
vocab_size
,
)
def
set_include_gpu_probs_tensor
(
self
):
# NGram don't need gpu sampler
pass
def
execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
None
:
"""NGram doesn't depend on model execution, just pass this function"""
pass
def
determine_num_available_blocks
(
self
)
->
None
:
"""NGram doesn't depend on model execution, no need to check blocks"""
pass
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
"""As there is no cache need to handle, just pass this function"""
pass
def
get_cache_block_size_bytes
(
self
):
"""Return the size of a cache block in bytes."""
return
0
def
sampler_output
(
self
,
execute_model_req
:
ExecuteModelRequest
,
...
...
@@ -97,7 +75,6 @@ class NGramWorker(LoraNotSupportedWorkerBase):
-
1
,
):
ngram_tensor
=
input_ids
[
-
ngram_size
:]
proposal_start_idx
=
None
if
ngram_size
==
1
:
# Do not match itself and do not use unfold and all
matches
=
(
input_ids
[:
-
1
]
==
ngram_tensor
)
...
...
@@ -161,7 +138,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
speculative tokens per sequence is determined by max_proposal_len.
"""
return
self
.
_proposer
.
get_proposals
(
execute_model_req
)
return
self
.
_proposer
.
get_
spec_
proposals
(
execute_model_req
)
def
_raise_if_unsupported
(
self
,
...
...
vllm/spec_decode/proposer_worker_base.py
0 → 100644
View file @
f48954a4
from
abc
import
ABC
,
abstractmethod
from
typing
import
List
,
Optional
,
Tuple
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposer
from
vllm.worker.worker_base
import
WorkerBase
class ProposerWorkerBase(WorkerBase, SpeculativeProposer):
    """Common interface for workers that act as speculative proposers.

    Combines ``WorkerBase`` with ``SpeculativeProposer``. Subclasses have to
    implement ``sampler_output``; ``set_include_gpu_probs_tensor`` is an
    optional hook.
    """

    @abstractmethod
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        """Produce sampler outputs for the given request.

        Args:
            execute_model_req: The request to generate outputs for.
            sample_len: Requested sample length.

        Returns:
            A pair of (optional list of ``SamplerOutput``, bool flag).
            NOTE(review): the flag presumably tells the caller whether the
            outputs are transposed -- confirm against the proposer that
            consumes this value.
        """
        raise NotImplementedError

    def set_include_gpu_probs_tensor(self):
        """Optional hook; the default implementation does nothing."""
class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
    """Base for proposer workers that do not use a model with a KV cache.

    Supplies trivial implementations of the worker methods below so that
    subclasses do not have to define them.
    """

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Return an empty list; proposals come from get_spec_proposals."""
        no_outputs: List[SamplerOutput] = []
        return no_outputs

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Always raise: this is only called on the target model."""
        raise NotImplementedError

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Do nothing; both block counts are ignored."""
        return None

    def get_cache_block_size_bytes(self) -> int:
        """Report a cache block size of zero bytes."""
        return 0
vllm/spec_decode/spec_decode_worker.py
View file @
f48954a4
...
...
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple
import
torch
from
vllm.config
import
SpeculativeConfig
from
vllm.distributed.communication_op
import
broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
...
...
@@ -14,6 +15,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.util
import
(
create_sequence_group_output
,
get_all_num_logprobs
,
get_all_seq_ids
,
get_sampled_token_logprobs
,
nvtx_range
,
...
...
@@ -29,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
"""
assert
"speculative_config"
in
kwargs
speculative_config
=
kwargs
.
get
(
"speculative_config"
)
speculative_config
:
SpeculativeConfig
=
kwargs
.
get
(
"speculative_config"
)
assert
speculative_config
is
not
None
target_worker
=
Worker
(
*
args
,
**
kwargs
)
...
...
@@ -108,16 +110,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
logger
.
info
(
"Configuring SpecDecodeWorker with proposer=%s"
,
type
(
proposer_worker
))
return
SpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
disable_by_batch_size
=
disable_by_batch_size
,
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
disable_bonus_tokens
,
))
return
SpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
disable_by_batch_size
=
disable_by_batch_size
,
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
disable_bonus_tokens
))
def
__init__
(
self
,
proposer_worker
:
WorkerBase
,
proposer_worker
:
Proposer
WorkerBase
,
scorer_worker
:
WorkerBase
,
rejection_sampler
:
RejectionSampler
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
...
...
@@ -260,7 +261,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# This is required as if the number of draft model runs changes
# dynamically, the non-driver workers won't know unless we perform a
# communication to inform the
n
.
# communication to inform the
m
.
broadcast_dict
=
dict
(
num_lookahead_slots
=
num_lookahead_slots
,
disable_all_speculation
=
disable_all_speculation
,
...
...
vllm/spec_decode/top1_proposer.py
View file @
f48954a4
...
...
@@ -6,8 +6,8 @@ from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.util
import
sampler_output_to_torch
from
vllm.worker.worker_base
import
WorkerBase
class
Top1Proposer
(
SpeculativeProposer
):
...
...
@@ -29,7 +29,7 @@ class Top1Proposer(SpeculativeProposer):
def
__init__
(
self
,
worker
:
WorkerBase
,
worker
:
Proposer
WorkerBase
,
device
:
str
,
vocab_size
:
int
,
max_proposal_len
:
Optional
[
int
]
=
None
,
...
...
@@ -39,7 +39,7 @@ class Top1Proposer(SpeculativeProposer):
self
.
max_proposal_len
=
max_proposal_len
self
.
_vocab_size
=
vocab_size
def
get_proposals
(
def
get_
spec_
proposals
(
self
,
execute_model_req
:
ExecuteModelRequest
,
)
->
SpeculativeProposals
:
...
...
@@ -148,7 +148,8 @@ class Top1Proposer(SpeculativeProposer):
nonzero_proposal_len_indices
,
)
def
_remove_no_proposal_seqs
(
self
,
proposal_lens
,
maybe_sampler_output
,
@
staticmethod
def
_remove_no_proposal_seqs
(
proposal_lens
,
maybe_sampler_output
,
nonzero_proposal_len_indices
,
transposed
):
"""Remove sequences from nonzero_proposal_len_indices and reset
their proposal_len to 0 the draft worker does not provide a proposal
...
...
@@ -207,7 +208,7 @@ class Top1Proposer(SpeculativeProposer):
self
,
batch_size
:
int
,
proposal_len
:
int
,
maybe_sampler_output
:
Optional
[
SamplerOutput
],
maybe_sampler_output
:
Optional
[
List
[
SamplerOutput
]
]
,
proposal_lens
:
List
[
int
],
nonzero_proposal_len_indices
:
List
[
int
],
sampler_transposed
:
bool
,
...
...
@@ -218,25 +219,19 @@ class Top1Proposer(SpeculativeProposer):
if
maybe_sampler_output
is
None
:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens
=
torch
.
full
(
size
=
(
batch_size
,
proposal_len
,
),
fill_value
=-
1
,
dtype
=
torch
.
long
,
device
=
self
.
_device
,
)
proposal_probs
=
torch
.
zeros
(
batch_size
,
proposal_len
,
self
.
_vocab_size
,
dtype
=
torch
.
float32
,
device
=
self
.
_device
,
)
proposal_lens_tensor
=
torch
.
zeros
(
len
(
proposal_lens
),
dtype
=
torch
.
long
,
device
=
self
.
_device
)
proposal_tokens
=
torch
.
tensor
(
-
1
,
dtype
=
torch
.
long
,
device
=
self
.
_device
).
expand
(
batch_size
,
proposal_len
)
proposal_probs
=
torch
.
tensor
(
0
,
dtype
=
torch
.
float32
,
device
=
self
.
_device
).
expand
(
batch_size
,
proposal_len
,
self
.
_vocab_size
)
proposal_lens_tensor
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
self
.
_device
).
expand
(
len
(
proposal_lens
))
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
sampler_output
=
maybe_sampler_output
...
...
@@ -246,18 +241,14 @@ class Top1Proposer(SpeculativeProposer):
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens
=
torch
.
full
(
entire_proposal_tokens
=
proposal_tokens
.
new_
full
(
size
=
(
batch_size
,
*
proposal_tokens
.
shape
[
1
:]),
fill_value
=-
1
,
dtype
=
torch
.
long
,
device
=
self
.
_device
,
)
entire_proposal_tokens
[
nonzero_proposal_len_indices
]
=
proposal_tokens
entire_proposal_probs
=
torch
.
zeros
(
entire_proposal_probs
=
proposal_probs
.
new_
zeros
(
batch_size
,
*
proposal_probs
.
shape
[
1
:],
dtype
=
torch
.
float32
,
device
=
self
.
_device
,
)
entire_proposal_probs
[
nonzero_proposal_len_indices
]
=
proposal_probs
...
...
vllm/spec_decode/util.py
View file @
f48954a4
from
contextlib
import
contextmanager
from
itertools
import
chain
from
typing
import
Dict
,
List
,
Tuple
import
torch
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroupOutput
,
SequenceOutput
)
SequenceOutput
)
SeqId
=
int
...
...
@@ -16,11 +15,7 @@ def get_all_seq_ids(
"""Given a list of SequenceGroupMetadata, create a list of all
sequence ids.
"""
return
list
(
chain
.
from_iterable
([
seq_group_metadata
.
seq_data
.
keys
()
for
seq_group_metadata
in
seq_group_metadata_list
]))
return
[
seq_id
for
sg
in
seq_group_metadata_list
for
seq_id
in
sg
.
seq_data
]
def
get_all_num_logprobs
(
...
...
@@ -68,7 +63,7 @@ def create_sequence_group_output(
seq_id
:
SeqId
,
topk_token_ids
:
List
[
int
],
topk_logprobs
:
List
[
float
],
)
->
SequenceGroupOutput
:
)
->
Completion
SequenceGroupOutput
:
"""Create a SequenceGroupOutput given the sampling results.
Args:
...
...
vllm/transformers_utils/config.py
View file @
f48954a4
from
typing
import
Dict
,
Optional
from
typing
import
Dict
,
Optional
,
Type
from
transformers
import
AutoConfig
,
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.configs
import
(
ChatGLMConfig
,
DbrxConfig
,
JAISConfig
,
MPTConfig
,
RWConfig
)
logger
=
init_logger
(
__name__
)
_CONFIG_REGISTRY
:
Dict
[
str
,
PretrainedConfig
]
=
{
_CONFIG_REGISTRY
:
Dict
[
str
,
Type
[
PretrainedConfig
]
]
=
{
"chatglm"
:
ChatGLMConfig
,
"dbrx"
:
DbrxConfig
,
"mpt"
:
MPTConfig
,
...
...
@@ -22,8 +23,13 @@ def get_config(model: str,
trust_remote_code
:
bool
,
revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
dict
]
=
None
)
->
PretrainedConfig
:
rope_scaling
:
Optional
[
dict
]
=
None
,
rope_theta
:
Optional
[
float
]
=
None
)
->
PretrainedConfig
:
try
:
if
VLLM_USE_MODELSCOPE
:
from
modelscope
import
AutoConfig
else
:
from
transformers
import
AutoConfig
config
=
AutoConfig
.
from_pretrained
(
model
,
trust_remote_code
=
trust_remote_code
,
...
...
@@ -45,10 +51,12 @@ def get_config(model: str,
config
=
config_class
.
from_pretrained
(
model
,
revision
=
revision
,
code_revision
=
code_revision
)
if
rope_scaling
is
not
None
:
logger
.
info
(
"Updating rope_scaling from %r to %r"
,
getattr
(
config
,
"rope_scaling"
,
None
),
rope_scaling
)
config
.
update
({
"rope_scaling"
:
rope_scaling
})
for
key
,
value
in
[(
"rope_scaling"
,
rope_scaling
),
(
"rope_theta"
,
rope_theta
)]:
if
value
is
not
None
:
logger
.
info
(
"Updating %s from %r to %r"
,
key
,
getattr
(
config
,
key
,
None
),
value
)
config
.
update
({
key
:
value
})
return
config
...
...
@@ -63,4 +71,4 @@ def get_hf_text_config(config: PretrainedConfig):
assert
hasattr
(
config
.
text_config
,
"num_attention_heads"
)
return
config
.
text_config
else
:
return
config
\ No newline at end of file
return
config
vllm/transformers_utils/image_processor.py
0 → 100644
View file @
f48954a4
from
functools
import
lru_cache
from
typing
import
Optional
from
transformers
import
AutoImageProcessor
from
transformers.image_processing_utils
import
BaseImageProcessor
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def get_image_processor(
    processor_name: str,
    *args,
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    **kwargs,
) -> BaseImageProcessor:
    """Load an image processor by name through ``AutoImageProcessor``.

    Args:
        processor_name: Name passed to ``AutoImageProcessor.from_pretrained``.
        *args: Extra positional arguments forwarded unchanged.
        trust_remote_code: Forwarded to ``from_pretrained``; also decides how
            a ``ValueError`` from loading is reported.
        revision: Forwarded to ``from_pretrained``.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        The loaded image processor.

    Raises:
        RuntimeError: If loading raised ``ValueError`` while
            ``trust_remote_code`` was false; the original error is chained.
        ValueError: Re-raised as-is when ``trust_remote_code`` was true.
    """
    try:
        loaded: BaseImageProcessor = AutoImageProcessor.from_pretrained(
            processor_name,
            *args,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs)
    except ValueError as e:
        # With remote code already trusted there is nothing more to suggest,
        # so the original error is surfaced untouched.
        if trust_remote_code:
            raise e

        # AutoImageProcessor does not distinguish "processor class missing or
        # not importable" from other ValueErrors (unlike AutoTokenizer), so
        # the --trust-remote-code hint is offered for any ValueError here.
        hint = (
            "Failed to load the image processor. If the image processor is "
            "a custom processor not yet available in the HuggingFace "
            "transformers library, consider setting "
            "`trust_remote_code=True` in LLM or using the "
            "`--trust-remote-code` flag in the CLI.")
        raise RuntimeError(hint) from e

    return loaded
cached_get_image_processor
=
lru_cache
(
get_image_processor
)
vllm/utils.py
View file @
f48954a4
...
...
@@ -17,10 +17,12 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Hashable
,
List
,
Optional
,
OrderedDict
,
Tuple
,
TypeVar
,
Union
)
import
numpy
as
np
import
psutil
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
enable_trace_function_call
,
init_logger
T
=
TypeVar
(
"T"
)
...
...
@@ -147,12 +149,8 @@ def is_neuron() -> bool:
@
lru_cache
(
maxsize
=
None
)
def
get_max_shared_memory_bytes
(
gpu
:
int
=
0
)
->
int
:
"""Returns the maximum shared memory per thread block in bytes."""
# NOTE: This import statement should be executed lazily since
# the Neuron-X backend does not have the `cuda_utils` module.
from
vllm._C
import
cuda_utils
max_shared_mem
=
(
cuda_util
s
.
get_max_shared_memory_per_block_device_attribute
(
gpu
))
op
s
.
get_max_shared_memory_per_block_device_attribute
(
gpu
))
# value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
# will fail
assert
max_shared_mem
>
0
,
"max_shared_mem can not be zero"
...
...
@@ -288,7 +286,15 @@ def get_distributed_init_method(ip: str, port: int) -> str:
def
get_open_port
()
->
int
:
port
=
envs
.
VLLM_PORT
if
port
is
not
None
:
return
port
while
True
:
try
:
with
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
as
s
:
s
.
bind
((
""
,
port
))
return
port
except
OSError
:
port
+=
1
# Increment port number if already in use
logger
.
info
(
"Port %d is already in use, trying port %d"
,
port
-
1
,
port
)
# try ipv4
try
:
with
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
as
s
:
...
...
@@ -501,11 +507,6 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]:
f
"(e.g., 1, 2, 3). Given input:
{
s
}
"
)
from
e
def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
    """Return a new list: ``x`` right-padded with ``pad`` up to ``max_len``.

    Args:
        x: Values to pad; not modified.
        max_len: Target length of the result.
        pad: Value appended until the result has ``max_len`` items.

    Returns:
        A new list of length ``max_len`` that starts with the items of ``x``.

    Raises:
        AssertionError: If ``x`` is longer than ``max_len``.
    """
    missing = max_len - len(x)
    if missing < 0:
        # Raised explicitly rather than via `assert` so the check still runs
        # under `python -O`; without it an over-long input would be returned
        # unpadded and longer than max_len. The exception type is unchanged.
        raise AssertionError(
            f"len(x)={len(x)} exceeds max_len={max_len}")
    return x + [pad] * missing
def
make_tensor_with_pad
(
x
:
List
[
List
[
int
]],
max_len
:
int
,
...
...
@@ -518,7 +519,10 @@ def make_tensor_with_pad(
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
padded_x
=
[
pad_to_max_length
(
x_i
,
max_len
,
pad
)
for
x_i
in
x
]
padded_x
=
np
.
zeros
([
len
(
x
),
max_len
],
dtype
=
np
.
int32
)
+
pad
for
ind
,
blocktb
in
enumerate
(
x
):
assert
len
(
blocktb
)
<=
max_len
padded_x
[
ind
,
:
len
(
blocktb
)]
=
blocktb
return
torch
.
tensor
(
padded_x
,
dtype
=
dtype
,
device
=
device
)
...
...
vllm/worker/cpu_model_runner.py
View file @
f48954a4
from
typing
import
List
,
Optional
,
Tuple
from
collections
import
defaultdict
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
...
...
@@ -11,6 +12,7 @@ from vllm.distributed import broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
...
...
@@ -63,6 +65,16 @@ class CPUModelRunner:
self
.
block_size
,
)
# Create processor for multi-modal data
if
self
.
vision_language_config
is
not
None
:
self
.
multi_modal_input_processor
=
MULTIMODAL_REGISTRY
\
.
create_input_processor
(
self
.
model_config
,
self
.
vision_language_config
,
)
else
:
self
.
multi_modal_input_processor
=
None
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
...
...
@@ -80,14 +92,15 @@ class CPUModelRunner:
def
_prepare_prompt
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
List
[
int
],
Optional
[
torch
.
Tensor
]]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
List
[
int
],
Dict
[
str
,
torch
.
Tensor
]]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
seq_lens
:
List
[
int
]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
multi_modal_kwargs_list
:
Dict
[
str
,
List
[
torch
.
Tensor
]]
=
defaultdict
(
list
)
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
...
...
@@ -108,9 +121,17 @@ class CPUModelRunner:
# is always the first token in the sequence.
input_positions
.
extend
(
list
(
range
(
computed_len
,
seq_len
)))
if
seq_group_metadata
.
multi_modal_data
:
multi_modal_input_list
.
append
(
seq_group_metadata
.
multi_modal_data
.
data
)
mm_data
=
seq_group_metadata
.
multi_modal_data
if
mm_data
is
not
None
:
# Process multi-modal data
if
self
.
multi_modal_input_processor
is
None
:
raise
ValueError
(
"Multi-modal inputs are only supported by "
"vision language models."
)
mm_kwargs
=
self
.
multi_modal_input_processor
(
mm_data
)
for
k
,
v
in
mm_kwargs
.
items
():
multi_modal_kwargs_list
[
k
].
append
(
v
)
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
...
...
@@ -134,14 +155,10 @@ class CPUModelRunner:
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
if
multi_modal_input_list
:
assert
self
.
vision_language_config
,
(
"Multi-modal inputs are only supported by "
"vision language models."
)
multi_modal_input
=
torch
.
cat
(
multi_modal_input_list
,
dim
=
0
).
to
(
self
.
device
)
else
:
multi_modal_input
=
None
multi_modal_kwargs
=
{
k
:
torch
.
cat
(
v
,
dim
=
0
).
to
(
self
.
device
)
for
k
,
v
in
multi_modal_kwargs_list
.
items
()
}
num_prompt_tokens
=
len
(
input_tokens
)
...
...
@@ -167,7 +184,7 @@ class CPUModelRunner:
slot_mapping
=
slot_mapping
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
seq_lens
,
multi_modal_
input
)
multi_modal_
kwargs
)
def
_prepare_decode
(
self
,
...
...
@@ -257,8 +274,8 @@ class CPUModelRunner:
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
SamplingMetadata
,
Optional
[
torch
.
Tensor
]]:
multi_modal_
input
=
None
Optional
[
Dict
[
str
,
torch
.
Tensor
]]
]
:
multi_modal_
kwargs
=
None
if
self
.
is_driver_worker
:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
...
...
@@ -266,7 +283,7 @@ class CPUModelRunner:
# Prepare input tensors.
if
is_prompt
:
(
input_tokens
,
input_positions
,
attn_metadata
,
seq_lens
,
multi_modal_
input
multi_modal_
kwargs
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
(
input_tokens
,
input_positions
,
...
...
@@ -307,7 +324,7 @@ class CPUModelRunner:
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
multi_modal_
input
)
sampling_metadata
,
multi_modal_
kwargs
)
@
torch
.
inference_mode
()
def
execute_model
(
...
...
vllm/worker/embedding_model_runner.py
View file @
f48954a4
...
...
@@ -90,7 +90,7 @@ class EmbeddingModelRunner(ModelRunner):
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
PoolingMetadata
,
Set
[
LoRARequest
],
LoRAMapping
,
torch
.
Tensor
]:
Set
[
LoRARequest
],
LoRAMapping
,
Dict
[
str
,
torch
.
Tensor
]
]
:
if
self
.
is_driver_worker
:
assert
seq_group_metadata_list
is
not
None
# Prepare input tensors.
...
...
@@ -102,7 +102,7 @@ class EmbeddingModelRunner(ModelRunner):
_
,
lora_mapping
,
lora_requests
,
multi_modal_
input
,
multi_modal_
kwargs
,
slot_mapping
,
num_prefill_tokens
,
num_decode_tokens
,
...
...
@@ -117,7 +117,7 @@ class EmbeddingModelRunner(ModelRunner):
"input_positions"
:
input_positions
,
"lora_requests"
:
lora_requests
,
"lora_mapping"
:
lora_mapping
,
"multi_modal_
input
"
:
multi_modal_
input
,
"multi_modal_
kwargs
"
:
multi_modal_
kwargs
,
"num_prefill_tokens"
:
num_prefill_tokens
,
"num_decode_tokens"
:
num_decode_tokens
,
"slot_mapping"
:
slot_mapping
,
...
...
@@ -132,7 +132,7 @@ class EmbeddingModelRunner(ModelRunner):
input_positions
=
metadata_dict
.
pop
(
"input_positions"
)
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
multi_modal_
input
=
metadata_dict
.
pop
(
"multi_modal_
input
"
)
multi_modal_
kwargs
=
metadata_dict
.
pop
(
"multi_modal_
kwargs
"
)
if
metadata_dict
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
...
...
@@ -143,7 +143,7 @@ class EmbeddingModelRunner(ModelRunner):
prompt_lens
=
None
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
pooling_metadata
,
lora_requests
,
lora_mapping
,
multi_modal_
input
)
lora_requests
,
lora_mapping
,
multi_modal_
kwargs
)
def
_prepare_pooling
(
self
,
...
...
vllm/worker/model_runner.py
View file @
f48954a4
import
gc
import
time
import
warnings
from
collections
import
defaultdict
from
typing
import
Dict
,
List
,
NamedTuple
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
as
np
...
...
@@ -18,9 +20,9 @@ from vllm.lora.request import LoRARequest
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
(
CudaMemoryProfiler
,
get_kv_cache_torch_dtype
,
is_hip
,
is_pin_memory_available
,
make_tensor_with_pad
)
...
...
@@ -34,6 +36,7 @@ _BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE
=
[
1
,
2
,
4
]
+
[
_BATCH_SIZE_ALIGNMENT
*
i
for
i
in
range
(
1
,
33
)
]
_NUM_WARMUP_ITERS
=
2
class
ModelInput
(
NamedTuple
):
...
...
@@ -44,7 +47,7 @@ class ModelInput(NamedTuple):
query_lens
:
List
[
int
]
lora_mapping
:
Optional
[
LoRAMapping
]
lora_requests
:
Set
[
LoRARequest
]
multi_modal_
input
:
Optional
[
torch
.
Tensor
]
multi_modal_
kwargs
:
Dict
[
str
,
torch
.
Tensor
]
slot_mapping
:
torch
.
Tensor
num_prefill_tokens
:
int
num_decode_tokens
:
int
...
...
@@ -60,7 +63,7 @@ class ModelInput(NamedTuple):
query_lens
=
[],
lora_mapping
=
None
,
lora_requests
=
set
(),
multi_modal_
input
=
None
,
multi_modal_
kwargs
=
{}
,
slot_mapping
=
torch
.
empty
(
0
,
device
=
device
),
num_prefill_tokens
=
0
,
num_decode_tokens
=
0
,
...
...
@@ -122,6 +125,16 @@ class ModelRunner:
self
.
block_size
,
)
# Create processor for multi-modal data
if
self
.
vision_language_config
is
not
None
:
self
.
multi_modal_input_processor
=
MULTIMODAL_REGISTRY
\
.
create_input_processor
(
self
.
model_config
,
self
.
vision_language_config
,
)
else
:
self
.
multi_modal_input_processor
=
None
# Lazy initialization
self
.
model
:
nn
.
Module
# Set after load_model
# Set if the backend is flashinfer.
...
...
@@ -242,7 +255,8 @@ class ModelRunner:
context_lens
:
List
[
int
]
=
[]
query_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
multi_modal_kwargs_list
:
Dict
[
str
,
List
[
torch
.
Tensor
]]
=
defaultdict
(
list
)
decode_only
=
True
num_prefills
=
0
num_prefill_tokens
=
0
...
...
@@ -415,11 +429,19 @@ class ModelRunner:
[
lora_id
]
*
(
query_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
is
not
None
else
1
))
if
seq_group_metadata
.
multi_modal_data
:
multi_modal_input_list
.
append
(
seq_group_metadata
.
multi_modal_data
.
data
)
mm_data
=
seq_group_metadata
.
multi_modal_data
if
mm_data
is
not
None
:
# Process multi-modal data
if
self
.
multi_modal_input_processor
is
None
:
raise
ValueError
(
"Multi-modal inputs are only supported by "
"vision language models."
)
mm_kwargs
=
self
.
multi_modal_input_processor
(
mm_data
)
for
k
,
v
in
mm_kwargs
.
items
():
multi_modal_kwargs_list
[
k
].
append
(
v
)
if
_is_block_tables_empty
(
seq_group_metadata
.
block_tables
):
# During memory profiling, the block tables are not
...
...
@@ -505,26 +527,6 @@ class ModelRunner:
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
if
multi_modal_input_list
:
assert
self
.
vision_language_config
,
(
"Multi-modal inputs are only supported by "
"vision language models."
)
multi_modal_input
=
torch
.
cat
(
multi_modal_input_list
,
dim
=
0
).
to
(
self
.
device
)
else
:
multi_modal_input
=
None
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
...
...
@@ -532,11 +534,6 @@ class ModelRunner:
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
...
...
@@ -589,6 +586,21 @@ class ModelRunner:
seq_start_loc
=
seq_start_loc
,
data_type
=
kv_cache_dtype
)
else
:
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
...
...
@@ -614,6 +626,11 @@ class ModelRunner:
else
:
lora_mapping
=
None
multi_modal_kwargs
=
{
k
:
torch
.
cat
(
v
,
dim
=
0
).
to
(
self
.
device
)
for
k
,
v
in
multi_modal_kwargs_list
.
items
()
}
return
ModelInput
(
input_tokens
=
input_tokens_tensor
,
input_positions
=
input_positions_tensor
,
...
...
@@ -622,7 +639,7 @@ class ModelRunner:
query_lens
=
query_lens
,
lora_mapping
=
lora_mapping
,
lora_requests
=
lora_requests
,
multi_modal_
input
=
multi_modal_
input
,
multi_modal_
kwargs
=
multi_modal_
kwargs
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
...
...
@@ -633,7 +650,7 @@ class ModelRunner:
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
SamplingMetadata
,
Set
[
LoRARequest
],
LoRAMapping
,
torch
.
Tensor
]:
Set
[
LoRARequest
],
LoRAMapping
,
Dict
[
str
,
torch
.
Tensor
]
]
:
if
self
.
is_driver_worker
:
assert
seq_group_metadata_list
is
not
None
# Prepare input tensors.
...
...
@@ -645,7 +662,7 @@ class ModelRunner:
query_lens
,
lora_mapping
,
lora_requests
,
multi_modal_
input
,
multi_modal_
kwargs
,
slot_mapping
,
num_prefill_tokens
,
num_decode_tokens
,
...
...
@@ -662,7 +679,7 @@ class ModelRunner:
sampling_metadata
.
selected_token_indices
,
"lora_requests"
:
lora_requests
,
"lora_mapping"
:
lora_mapping
,
"multi_modal_
input
"
:
multi_modal_
input
,
"multi_modal_
kwargs
"
:
multi_modal_
kwargs
,
"num_prefill_tokens"
:
num_prefill_tokens
,
"num_decode_tokens"
:
num_decode_tokens
,
"slot_mapping"
:
slot_mapping
,
...
...
@@ -679,7 +696,7 @@ class ModelRunner:
"selected_token_indices"
)
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
multi_modal_
input
=
metadata_dict
.
pop
(
"multi_modal_
input
"
)
multi_modal_
kwargs
=
metadata_dict
.
pop
(
"multi_modal_
kwargs
"
)
if
metadata_dict
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
...
...
@@ -694,7 +711,7 @@ class ModelRunner:
return
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
lora_requests
,
lora_mapping
,
multi_modal_
input
)
multi_modal_
kwargs
)
@
torch
.
inference_mode
()
def
execute_model
(
...
...
@@ -703,7 +720,7 @@ class ModelRunner:
kv_caches
:
List
[
torch
.
Tensor
],
)
->
Optional
[
SamplerOutput
]:
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
lora_requests
,
lora_mapping
,
multi_modal_
input
lora_requests
,
lora_mapping
,
multi_modal_
kwargs
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
if
self
.
lora_config
:
...
...
@@ -717,15 +734,14 @@ class ModelRunner:
model_executable
=
self
.
graph_runners
[
graph_batch_size
]
else
:
model_executable
=
self
.
model
execute_model_kwargs
=
{
"input_ids"
:
input_tokens
,
"positions"
:
input_positions
,
"kv_caches"
:
kv_caches
,
"attn_metadata"
:
attn_metadata
,
}
if
self
.
vision_language_config
:
execute_model_kwargs
.
update
({
"image_input"
:
multi_modal_input
})
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
hidden_states
=
model_executable
(
input_ids
=
input_tokens
,
positions
=
input_positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
**
multi_modal_kwargs
,
)
# Compute the logits.
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
...
...
@@ -781,16 +797,24 @@ class ModelRunner:
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
if
self
.
vision_language_config
:
model_config
=
self
.
model_config
vlm_config
=
self
.
vision_language_config
if
vlm_config
:
max_num_seqs
=
min
(
max_num_seqs
,
int
(
max_num_batched_tokens
/
self
.
vision_language_config
.
image_feature_size
))
int
(
max_num_batched_tokens
/
vlm_config
.
image_feature_size
))
for
group_id
in
range
(
max_num_seqs
):
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
seq_data
,
fake_multi_modal_input
=
_prepare_fake_inputs
(
seq_len
,
self
.
vision_language_config
)
if
vlm_config
is
None
:
seq_data
=
SequenceData
([
0
]
*
seq_len
)
dummy_multi_modal_data
=
None
else
:
seq_data
,
dummy_multi_modal_data
=
MULTIMODAL_REGISTRY
\
.
dummy_data_for_profiling
(
seq_len
,
model_config
,
vlm_config
)
seq
=
SequenceGroupMetadata
(
request_id
=
str
(
group_id
),
is_prompt
=
True
,
...
...
@@ -799,7 +823,7 @@ class ModelRunner:
block_tables
=
None
,
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
if
dummy_lora_requests_per_seq
else
None
,
multi_modal_data
=
fake
_multi_modal_
input
,
multi_modal_data
=
dummy
_multi_modal_
data
,
)
seqs
.
append
(
seq
)
...
...
@@ -871,6 +895,10 @@ class ModelRunner:
seq_lens
=
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
).
cuda
()
block_tables
=
torch
.
from_numpy
(
self
.
graph_block_tables
).
cuda
()
# Prepare buffer for outputs. These will be reused for all batch sizes.
# It will be filled after the first graph capture.
hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
graph_batch_size
=
_get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
batch_size_capture_list
=
[
...
...
@@ -907,9 +935,11 @@ class ModelRunner:
self
.
set_active_loras
(
set
(),
lora_mapping
)
graph_runner
=
CUDAGraphRunner
(
self
.
model
)
graph_runner
.
capture
(
hidden_states
=
graph_runner
.
capture
(
input_tokens
[:
batch_size
],
input_positions
[:
batch_size
],
hidden_states
[:
batch_size
]
if
hidden_states
is
not
None
else
None
,
kv_caches
,
attn_metadata
,
memory_pool
=
self
.
graph_memory_pool
,
...
...
@@ -946,35 +976,46 @@ class CUDAGraphRunner:
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
Optional
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
memory_pool
:
Optional
[
Tuple
[
int
,
int
]],
stream
:
torch
.
cuda
.
Stream
,
**
kwargs
,
)
->
None
:
)
->
torch
.
Tensor
:
assert
self
.
_graph
is
None
# Run the model
once
without capturing the graph.
# Run the model
a few times
without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
**
kwargs
,
)
# Note one iteration is not enough for torch.jit.script
for
_
in
range
(
_NUM_WARMUP_ITERS
):
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
**
kwargs
,
)
torch
.
cuda
.
synchronize
()
# Capture the graph.
self
.
_graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
self
.
_graph
,
pool
=
memory_pool
,
stream
=
stream
):
hidden_states
=
self
.
model
(
output_
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
**
kwargs
,
)
if
hidden_states
is
not
None
:
hidden_states
.
copy_
(
output_hidden_states
)
else
:
hidden_states
=
output_hidden_states
del
output_hidden_states
# make sure `output_hidden_states` is deleted
# in the graph's memory pool
gc
.
collect
()
torch
.
cuda
.
synchronize
()
# Save the input and output buffers.
...
...
@@ -987,7 +1028,7 @@ class CUDAGraphRunner:
"block_tables"
:
attn_metadata
.
decode_metadata
.
block_tables
,
}
self
.
output_buffers
=
{
"hidden_states"
:
hidden_states
}
return
return
hidden_states
def
forward
(
self
,
...
...
@@ -1034,24 +1075,6 @@ def _get_graph_batch_size(batch_size: int) -> int:
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Build placeholder inputs for the profile run.

    Args:
        seq_len: Number of prompt tokens requested for the dummy sequence.
        vision_language_config: Vision-language settings, or a falsy value
            (e.g. ``None``) when no image input should be produced.

    Returns:
        A ``(SequenceData, image_input)`` tuple. ``image_input`` is ``None``
        when ``vision_language_config`` is falsy; otherwise it is a
        ``MultiModalData`` of type ``IMAGE`` wrapping an all-zero float16
        tensor of shape ``vision_language_config.image_input_shape``.
    """
    # Text-only case: the prompt is just `seq_len` zero tokens.
    if not vision_language_config:
        return SequenceData([0] * seq_len), None

    image_token_id = vision_language_config.image_token_id
    feature_size = vision_language_config.image_feature_size

    # Image placeholder tokens come first, followed by zero tokens for the
    # remainder. A non-positive remainder contributes no tokens at all
    # (list repetition by a count <= 0 yields an empty list).
    image_tokens = [image_token_id] * feature_size
    text_padding = [0] * (seq_len - feature_size)

    zero_pixels = torch.zeros(vision_language_config.image_input_shape,
                              dtype=torch.float16)
    dummy_image = MultiModalData(type=MultiModalData.Type.IMAGE,
                                 data=zero_pixels)
    return SequenceData(image_tokens + text_padding), dummy_image
def
_is_block_tables_empty
(
block_tables
:
Union
[
None
,
Dict
]):
"""
Check if block_tables is None or a dictionary with all None values.
...
...
Prev
1
…
9
10
11
12
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment