Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f48954a4
Commit
f48954a4
authored
Jun 12, 2024
by
zhuwenwen
Browse files
merge v0.5.0
parents
1dba29d3
8f89d720
Changes
253
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
322 additions
and
214 deletions
+322
-214
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+1
-1
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+9
-6
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+5
-28
vllm/spec_decode/proposer_worker_base.py
vllm/spec_decode/proposer_worker_base.py
+44
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+10
-9
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+21
-30
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+3
-8
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+17
-9
vllm/transformers_utils/image_processor.py
vllm/transformers_utils/image_processor.py
+45
-0
vllm/utils.py
vllm/utils.py
+16
-12
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+37
-20
vllm/worker/embedding_model_runner.py
vllm/worker/embedding_model_runner.py
+5
-5
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+109
-86
No files found.
vllm/spec_decode/interfaces.py
View file @
f48954a4
...
@@ -55,7 +55,7 @@ class SpeculativeScores:
...
@@ -55,7 +55,7 @@ class SpeculativeScores:
class
SpeculativeProposer
(
ABC
):
class
SpeculativeProposer
(
ABC
):
@
abstractmethod
@
abstractmethod
def
get_proposals
(
def
get_
spec_
proposals
(
self
,
self
,
execute_model_req
:
ExecuteModelRequest
,
execute_model_req
:
ExecuteModelRequest
,
)
->
SpeculativeProposals
:
)
->
SpeculativeProposals
:
...
...
vllm/spec_decode/multi_step_worker.py
View file @
f48954a4
...
@@ -7,11 +7,12 @@ import torch
...
@@ -7,11 +7,12 @@ import torch
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
class
MultiStepWorker
(
Worker
):
class
MultiStepWorker
(
Worker
,
ProposerWorkerBase
):
"""The MultiStepWorker is equivalent to a Worker except that it allows
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
allocated enough space to store the additional KV. This reduces overhead
...
@@ -33,7 +34,7 @@ class MultiStepWorker(Worker):
...
@@ -33,7 +34,7 @@ class MultiStepWorker(Worker):
super
().
init_device
()
super
().
init_device
()
self
.
_proposer
=
Top1Proposer
(
self
.
_proposer
=
Top1Proposer
(
weakref
.
proxy
(
self
),
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
self
.
device
,
self
.
device
,
self
.
vocab_size
,
self
.
vocab_size
,
max_proposal_len
=
self
.
max_model_len
,
max_proposal_len
=
self
.
max_model_len
,
...
@@ -92,11 +93,12 @@ class MultiStepWorker(Worker):
...
@@ -92,11 +93,12 @@ class MultiStepWorker(Worker):
speculative tokens per sequence is determined by max_proposal_len.
speculative tokens per sequence is determined by max_proposal_len.
"""
"""
return
self
.
_proposer
.
get_proposals
(
execute_model_req
)
return
self
.
_proposer
.
get_
spec_
proposals
(
execute_model_req
)
@
staticmethod
def
_append_new_tokens
(
def
_append_new_tokens
(
self
,
model_output
:
SamplerOutput
,
model_output
:
List
[
SamplerOutput
]
,
seq_group_metadata_list
:
SequenceGroupMetadata
)
->
None
:
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
)
->
None
:
"""Given model output from a single run, append the tokens to the
"""Given model output from a single run, append the tokens to the
sequences. This is normally done outside of the worker, but it is
sequences. This is normally done outside of the worker, but it is
required if the worker is to perform multiple forward passes.
required if the worker is to perform multiple forward passes.
...
@@ -116,8 +118,9 @@ class MultiStepWorker(Worker):
...
@@ -116,8 +118,9 @@ class MultiStepWorker(Worker):
seq
.
append_token_id
(
token_id
,
token_logprob
.
logprob
)
seq
.
append_token_id
(
token_id
,
token_logprob
.
logprob
)
seq
.
update_num_computed_tokens
(
1
)
seq
.
update_num_computed_tokens
(
1
)
@
staticmethod
def
_shallow_copy_inputs
(
def
_shallow_copy_inputs
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
)
->
List
[
SequenceGroupMetadata
]:
)
->
List
[
SequenceGroupMetadata
]:
"""Copy input data structures to remove side-effects when input data
"""Copy input data structures to remove side-effects when input data
structures are shared with other modules.
structures are shared with other modules.
...
...
vllm/spec_decode/ngram_worker.py
View file @
f48954a4
...
@@ -5,15 +5,16 @@ import torch
...
@@ -5,15 +5,16 @@ import torch
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
class
NGramWorker
(
LoraNotSupportedWorkerBase
):
class
NGramWorker
(
NonLLMProposerWorkerBase
,
LoraNotSupportedWorkerBase
):
"""NGramWorker provides a light drafter without need for model.
"""NGramWorker provides a light drafter without need for model.
Current NGramWorker only implement prompt lookup decoding,
Current NGramWorker only implement prompt lookup decoding,
and in future we may also do RAG type drafter and other scen
e
rios
and in future we may also do RAG type drafter and other scen
a
rios
which don't rely on LLM model to give proposals.
which don't rely on LLM model to give proposals.
"""
"""
...
@@ -38,34 +39,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
...
@@ -38,34 +39,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
# Current only support Top1Proposer
# Current only support Top1Proposer
self
.
_proposer
=
Top1Proposer
(
self
.
_proposer
=
Top1Proposer
(
weakref
.
proxy
(
self
),
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
device
=
self
.
device
,
device
=
self
.
device
,
vocab_size
=
self
.
vocab_size
,
vocab_size
=
self
.
vocab_size
,
)
)
def
set_include_gpu_probs_tensor
(
self
):
# NGram don't need gpu sampler
pass
def
execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
None
:
"""NGram doesn't depend on model execution, just pass this function"""
pass
def
determine_num_available_blocks
(
self
)
->
None
:
"""NGram doesn't depend on model execution, no need to check blocks"""
pass
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
"""As there is no cache need to handle, just pass this function"""
pass
def
get_cache_block_size_bytes
(
self
):
"""Return the size of a cache block in bytes."""
return
0
def
sampler_output
(
def
sampler_output
(
self
,
self
,
execute_model_req
:
ExecuteModelRequest
,
execute_model_req
:
ExecuteModelRequest
,
...
@@ -97,7 +75,6 @@ class NGramWorker(LoraNotSupportedWorkerBase):
...
@@ -97,7 +75,6 @@ class NGramWorker(LoraNotSupportedWorkerBase):
-
1
,
-
1
,
):
):
ngram_tensor
=
input_ids
[
-
ngram_size
:]
ngram_tensor
=
input_ids
[
-
ngram_size
:]
proposal_start_idx
=
None
if
ngram_size
==
1
:
if
ngram_size
==
1
:
# Do not match itself and do not use unfold and all
# Do not match itself and do not use unfold and all
matches
=
(
input_ids
[:
-
1
]
==
ngram_tensor
)
matches
=
(
input_ids
[:
-
1
]
==
ngram_tensor
)
...
@@ -161,7 +138,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
...
@@ -161,7 +138,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
speculative tokens per sequence is determined by max_proposal_len.
speculative tokens per sequence is determined by max_proposal_len.
"""
"""
return
self
.
_proposer
.
get_proposals
(
execute_model_req
)
return
self
.
_proposer
.
get_
spec_
proposals
(
execute_model_req
)
def
_raise_if_unsupported
(
def
_raise_if_unsupported
(
self
,
self
,
...
...
vllm/spec_decode/proposer_worker_base.py
0 → 100644
View file @
f48954a4
from
abc
import
ABC
,
abstractmethod
from
typing
import
List
,
Optional
,
Tuple
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposer
from
vllm.worker.worker_base
import
WorkerBase
class
ProposerWorkerBase
(
WorkerBase
,
SpeculativeProposer
):
"""Interface for proposer workers"""
@
abstractmethod
def
sampler_output
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
)
->
Tuple
[
Optional
[
List
[
SamplerOutput
]],
bool
]:
raise
NotImplementedError
def
set_include_gpu_probs_tensor
(
self
):
"""Implementation optional"""
pass
class
NonLLMProposerWorkerBase
(
ProposerWorkerBase
,
ABC
):
"""Proposer worker which does not use a model with kvcache"""
def
execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
"""get_spec_proposals is used to get the proposals"""
return
[]
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""This is never called on the proposer, only the target model"""
raise
NotImplementedError
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
pass
def
get_cache_block_size_bytes
(
self
)
->
int
:
return
0
vllm/spec_decode/spec_decode_worker.py
View file @
f48954a4
...
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple
...
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple
import
torch
import
torch
from
vllm.config
import
SpeculativeConfig
from
vllm.distributed.communication_op
import
broadcast_tensor_dict
from
vllm.distributed.communication_op
import
broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
...
@@ -14,6 +15,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
...
@@ -14,6 +15,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.util
import
(
create_sequence_group_output
,
from
vllm.spec_decode.util
import
(
create_sequence_group_output
,
get_all_num_logprobs
,
get_all_seq_ids
,
get_all_num_logprobs
,
get_all_seq_ids
,
get_sampled_token_logprobs
,
nvtx_range
,
get_sampled_token_logprobs
,
nvtx_range
,
...
@@ -29,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
...
@@ -29,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
"""
"""
assert
"speculative_config"
in
kwargs
assert
"speculative_config"
in
kwargs
speculative_config
=
kwargs
.
get
(
"speculative_config"
)
speculative_config
:
SpeculativeConfig
=
kwargs
.
get
(
"speculative_config"
)
assert
speculative_config
is
not
None
assert
speculative_config
is
not
None
target_worker
=
Worker
(
*
args
,
**
kwargs
)
target_worker
=
Worker
(
*
args
,
**
kwargs
)
...
@@ -108,16 +110,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -108,16 +110,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
logger
.
info
(
"Configuring SpecDecodeWorker with proposer=%s"
,
logger
.
info
(
"Configuring SpecDecodeWorker with proposer=%s"
,
type
(
proposer_worker
))
type
(
proposer_worker
))
return
SpecDecodeWorker
(
return
SpecDecodeWorker
(
proposer_worker
,
proposer_worker
,
scorer_worker
,
scorer_worker
,
disable_by_batch_size
=
disable_by_batch_size
,
disable_by_batch_size
=
disable_by_batch_size
,
rejection_sampler
=
RejectionSampler
(
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
disable_bonus_tokens
))
disable_bonus_tokens
=
disable_bonus_tokens
,
))
def
__init__
(
def
__init__
(
self
,
self
,
proposer_worker
:
WorkerBase
,
proposer_worker
:
Proposer
WorkerBase
,
scorer_worker
:
WorkerBase
,
scorer_worker
:
WorkerBase
,
rejection_sampler
:
RejectionSampler
,
rejection_sampler
:
RejectionSampler
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
...
@@ -260,7 +261,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -260,7 +261,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# This is required as if the number of draft model runs changes
# This is required as if the number of draft model runs changes
# dynamically, the non-driver workers won't know unless we perform a
# dynamically, the non-driver workers won't know unless we perform a
# communication to inform the
n
.
# communication to inform the
m
.
broadcast_dict
=
dict
(
broadcast_dict
=
dict
(
num_lookahead_slots
=
num_lookahead_slots
,
num_lookahead_slots
=
num_lookahead_slots
,
disable_all_speculation
=
disable_all_speculation
,
disable_all_speculation
=
disable_all_speculation
,
...
...
vllm/spec_decode/top1_proposer.py
View file @
f48954a4
...
@@ -6,8 +6,8 @@ from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
...
@@ -6,8 +6,8 @@ from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.util
import
sampler_output_to_torch
from
vllm.spec_decode.util
import
sampler_output_to_torch
from
vllm.worker.worker_base
import
WorkerBase
class
Top1Proposer
(
SpeculativeProposer
):
class
Top1Proposer
(
SpeculativeProposer
):
...
@@ -29,7 +29,7 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -29,7 +29,7 @@ class Top1Proposer(SpeculativeProposer):
def
__init__
(
def
__init__
(
self
,
self
,
worker
:
WorkerBase
,
worker
:
Proposer
WorkerBase
,
device
:
str
,
device
:
str
,
vocab_size
:
int
,
vocab_size
:
int
,
max_proposal_len
:
Optional
[
int
]
=
None
,
max_proposal_len
:
Optional
[
int
]
=
None
,
...
@@ -39,7 +39,7 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -39,7 +39,7 @@ class Top1Proposer(SpeculativeProposer):
self
.
max_proposal_len
=
max_proposal_len
self
.
max_proposal_len
=
max_proposal_len
self
.
_vocab_size
=
vocab_size
self
.
_vocab_size
=
vocab_size
def
get_proposals
(
def
get_
spec_
proposals
(
self
,
self
,
execute_model_req
:
ExecuteModelRequest
,
execute_model_req
:
ExecuteModelRequest
,
)
->
SpeculativeProposals
:
)
->
SpeculativeProposals
:
...
@@ -148,7 +148,8 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -148,7 +148,8 @@ class Top1Proposer(SpeculativeProposer):
nonzero_proposal_len_indices
,
nonzero_proposal_len_indices
,
)
)
def
_remove_no_proposal_seqs
(
self
,
proposal_lens
,
maybe_sampler_output
,
@
staticmethod
def
_remove_no_proposal_seqs
(
proposal_lens
,
maybe_sampler_output
,
nonzero_proposal_len_indices
,
transposed
):
nonzero_proposal_len_indices
,
transposed
):
"""Remove sequences from nonzero_proposal_len_indices and reset
"""Remove sequences from nonzero_proposal_len_indices and reset
their proposal_len to 0 the draft worker does not provide a proposal
their proposal_len to 0 the draft worker does not provide a proposal
...
@@ -207,7 +208,7 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -207,7 +208,7 @@ class Top1Proposer(SpeculativeProposer):
self
,
self
,
batch_size
:
int
,
batch_size
:
int
,
proposal_len
:
int
,
proposal_len
:
int
,
maybe_sampler_output
:
Optional
[
SamplerOutput
],
maybe_sampler_output
:
Optional
[
List
[
SamplerOutput
]
]
,
proposal_lens
:
List
[
int
],
proposal_lens
:
List
[
int
],
nonzero_proposal_len_indices
:
List
[
int
],
nonzero_proposal_len_indices
:
List
[
int
],
sampler_transposed
:
bool
,
sampler_transposed
:
bool
,
...
@@ -218,25 +219,19 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -218,25 +219,19 @@ class Top1Proposer(SpeculativeProposer):
if
maybe_sampler_output
is
None
:
if
maybe_sampler_output
is
None
:
# If no speculative tokens, the sampler output will be None.
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
# In this case we return empty proposals.
proposal_tokens
=
torch
.
full
(
proposal_tokens
=
torch
.
tensor
(
-
1
,
size
=
(
dtype
=
torch
.
long
,
batch_size
,
device
=
self
.
_device
).
expand
(
proposal_len
,
batch_size
,
proposal_len
)
),
proposal_probs
=
torch
.
tensor
(
0
,
fill_value
=-
1
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
long
,
device
=
self
.
_device
).
expand
(
device
=
self
.
_device
,
batch_size
,
proposal_len
,
)
self
.
_vocab_size
)
proposal_probs
=
torch
.
zeros
(
proposal_lens_tensor
=
torch
.
tensor
(
0
,
batch_size
,
dtype
=
torch
.
long
,
proposal_len
,
device
=
self
.
_device
).
expand
(
self
.
_vocab_size
,
len
(
proposal_lens
))
dtype
=
torch
.
float32
,
device
=
self
.
_device
,
)
proposal_lens_tensor
=
torch
.
zeros
(
len
(
proposal_lens
),
dtype
=
torch
.
long
,
device
=
self
.
_device
)
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
sampler_output
=
maybe_sampler_output
sampler_output
=
maybe_sampler_output
...
@@ -246,18 +241,14 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -246,18 +241,14 @@ class Top1Proposer(SpeculativeProposer):
# Now, reformat the output GPU tensors such that each sequence has
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens
=
torch
.
full
(
entire_proposal_tokens
=
proposal_tokens
.
new_
full
(
size
=
(
batch_size
,
*
proposal_tokens
.
shape
[
1
:]),
size
=
(
batch_size
,
*
proposal_tokens
.
shape
[
1
:]),
fill_value
=-
1
,
fill_value
=-
1
,
dtype
=
torch
.
long
,
device
=
self
.
_device
,
)
)
entire_proposal_tokens
[
nonzero_proposal_len_indices
]
=
proposal_tokens
entire_proposal_tokens
[
nonzero_proposal_len_indices
]
=
proposal_tokens
entire_proposal_probs
=
torch
.
zeros
(
entire_proposal_probs
=
proposal_probs
.
new_
zeros
(
batch_size
,
batch_size
,
*
proposal_probs
.
shape
[
1
:],
*
proposal_probs
.
shape
[
1
:],
dtype
=
torch
.
float32
,
device
=
self
.
_device
,
)
)
entire_proposal_probs
[
nonzero_proposal_len_indices
]
=
proposal_probs
entire_proposal_probs
[
nonzero_proposal_len_indices
]
=
proposal_probs
...
...
vllm/spec_decode/util.py
View file @
f48954a4
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
itertools
import
chain
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
,
Tuple
import
torch
import
torch
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroupOutput
,
SequenceOutput
)
SequenceOutput
)
SeqId
=
int
SeqId
=
int
...
@@ -16,11 +15,7 @@ def get_all_seq_ids(
...
@@ -16,11 +15,7 @@ def get_all_seq_ids(
"""Given a list of SequenceGroupMetadata, create a list of all
"""Given a list of SequenceGroupMetadata, create a list of all
sequence ids.
sequence ids.
"""
"""
return
list
(
return
[
seq_id
for
sg
in
seq_group_metadata_list
for
seq_id
in
sg
.
seq_data
]
chain
.
from_iterable
([
seq_group_metadata
.
seq_data
.
keys
()
for
seq_group_metadata
in
seq_group_metadata_list
]))
def
get_all_num_logprobs
(
def
get_all_num_logprobs
(
...
@@ -68,7 +63,7 @@ def create_sequence_group_output(
...
@@ -68,7 +63,7 @@ def create_sequence_group_output(
seq_id
:
SeqId
,
seq_id
:
SeqId
,
topk_token_ids
:
List
[
int
],
topk_token_ids
:
List
[
int
],
topk_logprobs
:
List
[
float
],
topk_logprobs
:
List
[
float
],
)
->
SequenceGroupOutput
:
)
->
Completion
SequenceGroupOutput
:
"""Create a SequenceGroupOutput given the sampling results.
"""Create a SequenceGroupOutput given the sampling results.
Args:
Args:
...
...
vllm/transformers_utils/config.py
View file @
f48954a4
from
typing
import
Dict
,
Optional
from
typing
import
Dict
,
Optional
,
Type
from
transformers
import
AutoConfig
,
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.configs
import
(
ChatGLMConfig
,
DbrxConfig
,
from
vllm.transformers_utils.configs
import
(
ChatGLMConfig
,
DbrxConfig
,
JAISConfig
,
MPTConfig
,
RWConfig
)
JAISConfig
,
MPTConfig
,
RWConfig
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_CONFIG_REGISTRY
:
Dict
[
str
,
PretrainedConfig
]
=
{
_CONFIG_REGISTRY
:
Dict
[
str
,
Type
[
PretrainedConfig
]
]
=
{
"chatglm"
:
ChatGLMConfig
,
"chatglm"
:
ChatGLMConfig
,
"dbrx"
:
DbrxConfig
,
"dbrx"
:
DbrxConfig
,
"mpt"
:
MPTConfig
,
"mpt"
:
MPTConfig
,
...
@@ -22,8 +23,13 @@ def get_config(model: str,
...
@@ -22,8 +23,13 @@ def get_config(model: str,
trust_remote_code
:
bool
,
trust_remote_code
:
bool
,
revision
:
Optional
[
str
]
=
None
,
revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
dict
]
=
None
)
->
PretrainedConfig
:
rope_scaling
:
Optional
[
dict
]
=
None
,
rope_theta
:
Optional
[
float
]
=
None
)
->
PretrainedConfig
:
try
:
try
:
if
VLLM_USE_MODELSCOPE
:
from
modelscope
import
AutoConfig
else
:
from
transformers
import
AutoConfig
config
=
AutoConfig
.
from_pretrained
(
config
=
AutoConfig
.
from_pretrained
(
model
,
model
,
trust_remote_code
=
trust_remote_code
,
trust_remote_code
=
trust_remote_code
,
...
@@ -45,10 +51,12 @@ def get_config(model: str,
...
@@ -45,10 +51,12 @@ def get_config(model: str,
config
=
config_class
.
from_pretrained
(
model
,
config
=
config_class
.
from_pretrained
(
model
,
revision
=
revision
,
revision
=
revision
,
code_revision
=
code_revision
)
code_revision
=
code_revision
)
if
rope_scaling
is
not
None
:
for
key
,
value
in
[(
"rope_scaling"
,
rope_scaling
),
logger
.
info
(
"Updating rope_scaling from %r to %r"
,
(
"rope_theta"
,
rope_theta
)]:
getattr
(
config
,
"rope_scaling"
,
None
),
rope_scaling
)
if
value
is
not
None
:
config
.
update
({
"rope_scaling"
:
rope_scaling
})
logger
.
info
(
"Updating %s from %r to %r"
,
key
,
getattr
(
config
,
key
,
None
),
value
)
config
.
update
({
key
:
value
})
return
config
return
config
...
@@ -63,4 +71,4 @@ def get_hf_text_config(config: PretrainedConfig):
...
@@ -63,4 +71,4 @@ def get_hf_text_config(config: PretrainedConfig):
assert
hasattr
(
config
.
text_config
,
"num_attention_heads"
)
assert
hasattr
(
config
.
text_config
,
"num_attention_heads"
)
return
config
.
text_config
return
config
.
text_config
else
:
else
:
return
config
return
config
\ No newline at end of file
vllm/transformers_utils/image_processor.py
0 → 100644
View file @
f48954a4
from
functools
import
lru_cache
from
typing
import
Optional
from
transformers
import
AutoImageProcessor
from
transformers.image_processing_utils
import
BaseImageProcessor
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def
get_image_processor
(
processor_name
:
str
,
*
args
,
trust_remote_code
:
bool
=
False
,
revision
:
Optional
[
str
]
=
None
,
**
kwargs
,
)
->
BaseImageProcessor
:
"""Gets an image processor for the given model name via HuggingFace."""
try
:
processor
:
BaseImageProcessor
=
AutoImageProcessor
.
from_pretrained
(
processor_name
,
*
args
,
trust_remote_code
=
trust_remote_code
,
revision
=
revision
,
**
kwargs
)
except
ValueError
as
e
:
# If the error pertains to the processor class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
# Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
if
not
trust_remote_code
:
err_msg
=
(
"Failed to load the image processor. If the image processor is "
"a custom processor not yet available in the HuggingFace "
"transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI."
)
raise
RuntimeError
(
err_msg
)
from
e
else
:
raise
e
return
processor
cached_get_image_processor
=
lru_cache
(
get_image_processor
)
vllm/utils.py
View file @
f48954a4
...
@@ -17,10 +17,12 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
...
@@ -17,10 +17,12 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Hashable
,
List
,
Optional
,
OrderedDict
,
Tuple
,
TypeVar
,
Hashable
,
List
,
Optional
,
OrderedDict
,
Tuple
,
TypeVar
,
Union
)
Union
)
import
numpy
as
np
import
psutil
import
psutil
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
enable_trace_function_call
,
init_logger
from
vllm.logger
import
enable_trace_function_call
,
init_logger
T
=
TypeVar
(
"T"
)
T
=
TypeVar
(
"T"
)
...
@@ -147,12 +149,8 @@ def is_neuron() -> bool:
...
@@ -147,12 +149,8 @@ def is_neuron() -> bool:
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
def
get_max_shared_memory_bytes
(
gpu
:
int
=
0
)
->
int
:
def
get_max_shared_memory_bytes
(
gpu
:
int
=
0
)
->
int
:
"""Returns the maximum shared memory per thread block in bytes."""
"""Returns the maximum shared memory per thread block in bytes."""
# NOTE: This import statement should be executed lazily since
# the Neuron-X backend does not have the `cuda_utils` module.
from
vllm._C
import
cuda_utils
max_shared_mem
=
(
max_shared_mem
=
(
cuda_util
s
.
get_max_shared_memory_per_block_device_attribute
(
gpu
))
op
s
.
get_max_shared_memory_per_block_device_attribute
(
gpu
))
# value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
# value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
# will fail
# will fail
assert
max_shared_mem
>
0
,
"max_shared_mem can not be zero"
assert
max_shared_mem
>
0
,
"max_shared_mem can not be zero"
...
@@ -288,7 +286,15 @@ def get_distributed_init_method(ip: str, port: int) -> str:
...
@@ -288,7 +286,15 @@ def get_distributed_init_method(ip: str, port: int) -> str:
def
get_open_port
()
->
int
:
def
get_open_port
()
->
int
:
port
=
envs
.
VLLM_PORT
port
=
envs
.
VLLM_PORT
if
port
is
not
None
:
if
port
is
not
None
:
return
port
while
True
:
try
:
with
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
as
s
:
s
.
bind
((
""
,
port
))
return
port
except
OSError
:
port
+=
1
# Increment port number if already in use
logger
.
info
(
"Port %d is already in use, trying port %d"
,
port
-
1
,
port
)
# try ipv4
# try ipv4
try
:
try
:
with
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
as
s
:
with
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
as
s
:
...
@@ -501,11 +507,6 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]:
...
@@ -501,11 +507,6 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]:
f
"(e.g., 1, 2, 3). Given input:
{
s
}
"
)
from
e
f
"(e.g., 1, 2, 3). Given input:
{
s
}
"
)
from
e
def
pad_to_max_length
(
x
:
List
[
int
],
max_len
:
int
,
pad
:
int
)
->
List
[
int
]:
assert
len
(
x
)
<=
max_len
return
x
+
[
pad
]
*
(
max_len
-
len
(
x
))
def
make_tensor_with_pad
(
def
make_tensor_with_pad
(
x
:
List
[
List
[
int
]],
x
:
List
[
List
[
int
]],
max_len
:
int
,
max_len
:
int
,
...
@@ -518,7 +519,10 @@ def make_tensor_with_pad(
...
@@ -518,7 +519,10 @@ def make_tensor_with_pad(
The padding is applied to the end of each inner list until it reaches
The padding is applied to the end of each inner list until it reaches
`max_len`.
`max_len`.
"""
"""
padded_x
=
[
pad_to_max_length
(
x_i
,
max_len
,
pad
)
for
x_i
in
x
]
padded_x
=
np
.
zeros
([
len
(
x
),
max_len
],
dtype
=
np
.
int32
)
+
pad
for
ind
,
blocktb
in
enumerate
(
x
):
assert
len
(
blocktb
)
<=
max_len
padded_x
[
ind
,
:
len
(
blocktb
)]
=
blocktb
return
torch
.
tensor
(
padded_x
,
dtype
=
dtype
,
device
=
device
)
return
torch
.
tensor
(
padded_x
,
dtype
=
dtype
,
device
=
device
)
...
...
vllm/worker/cpu_model_runner.py
View file @
f48954a4
from
typing
import
List
,
Optional
,
Tuple
from
collections
import
defaultdict
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -11,6 +12,7 @@ from vllm.distributed import broadcast_tensor_dict
...
@@ -11,6 +12,7 @@ from vllm.distributed import broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
from
vllm.utils
import
make_tensor_with_pad
...
@@ -63,6 +65,16 @@ class CPUModelRunner:
...
@@ -63,6 +65,16 @@ class CPUModelRunner:
self
.
block_size
,
self
.
block_size
,
)
)
# Create processor for multi-modal data
if
self
.
vision_language_config
is
not
None
:
self
.
multi_modal_input_processor
=
MULTIMODAL_REGISTRY
\
.
create_input_processor
(
self
.
model_config
,
self
.
vision_language_config
,
)
else
:
self
.
multi_modal_input_processor
=
None
# Lazy initialization.
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
self
.
model
:
nn
.
Module
# Set after init_Model
...
@@ -80,14 +92,15 @@ class CPUModelRunner:
...
@@ -80,14 +92,15 @@ class CPUModelRunner:
def
_prepare_prompt
(
def
_prepare_prompt
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
List
[
int
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
List
[
int
],
Dict
[
Optional
[
torch
.
Tensor
]]:
str
,
torch
.
Tensor
]]:
assert
len
(
seq_group_metadata_list
)
>
0
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
seq_lens
:
List
[
int
]
=
[]
seq_lens
:
List
[
int
]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
multi_modal_kwargs_list
:
Dict
[
str
,
List
[
torch
.
Tensor
]]
=
defaultdict
(
list
)
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
assert
seq_group_metadata
.
is_prompt
...
@@ -108,9 +121,17 @@ class CPUModelRunner:
...
@@ -108,9 +121,17 @@ class CPUModelRunner:
# is always the first token in the sequence.
# is always the first token in the sequence.
input_positions
.
extend
(
list
(
range
(
computed_len
,
seq_len
)))
input_positions
.
extend
(
list
(
range
(
computed_len
,
seq_len
)))
if
seq_group_metadata
.
multi_modal_data
:
mm_data
=
seq_group_metadata
.
multi_modal_data
multi_modal_input_list
.
append
(
if
mm_data
is
not
None
:
seq_group_metadata
.
multi_modal_data
.
data
)
# Process multi-modal data
if
self
.
multi_modal_input_processor
is
None
:
raise
ValueError
(
"Multi-modal inputs are only supported by "
"vision language models."
)
mm_kwargs
=
self
.
multi_modal_input_processor
(
mm_data
)
for
k
,
v
in
mm_kwargs
.
items
():
multi_modal_kwargs_list
[
k
].
append
(
v
)
# Compute the slot mapping.
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
...
@@ -134,14 +155,10 @@ class CPUModelRunner:
...
@@ -134,14 +155,10 @@ class CPUModelRunner:
slot
=
block_number
*
self
.
block_size
+
block_offset
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
slot_mapping
.
append
(
slot
)
if
multi_modal_input_list
:
multi_modal_kwargs
=
{
assert
self
.
vision_language_config
,
(
k
:
torch
.
cat
(
v
,
dim
=
0
).
to
(
self
.
device
)
"Multi-modal inputs are only supported by "
for
k
,
v
in
multi_modal_kwargs_list
.
items
()
"vision language models."
)
}
multi_modal_input
=
torch
.
cat
(
multi_modal_input_list
,
dim
=
0
).
to
(
self
.
device
)
else
:
multi_modal_input
=
None
num_prompt_tokens
=
len
(
input_tokens
)
num_prompt_tokens
=
len
(
input_tokens
)
...
@@ -167,7 +184,7 @@ class CPUModelRunner:
...
@@ -167,7 +184,7 @@ class CPUModelRunner:
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
)
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
seq_lens
,
return
(
input_tokens
,
input_positions
,
attn_metadata
,
seq_lens
,
multi_modal_
input
)
multi_modal_
kwargs
)
def
_prepare_decode
(
def
_prepare_decode
(
self
,
self
,
...
@@ -257,8 +274,8 @@ class CPUModelRunner:
...
@@ -257,8 +274,8 @@ class CPUModelRunner:
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
SamplingMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
SamplingMetadata
,
Optional
[
torch
.
Tensor
]]:
Optional
[
Dict
[
str
,
torch
.
Tensor
]]
]
:
multi_modal_
input
=
None
multi_modal_
kwargs
=
None
if
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
# NOTE: We assume that all sequences in the group are all prompts or
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
# all decodes.
...
@@ -266,7 +283,7 @@ class CPUModelRunner:
...
@@ -266,7 +283,7 @@ class CPUModelRunner:
# Prepare input tensors.
# Prepare input tensors.
if
is_prompt
:
if
is_prompt
:
(
input_tokens
,
input_positions
,
attn_metadata
,
seq_lens
,
(
input_tokens
,
input_positions
,
attn_metadata
,
seq_lens
,
multi_modal_
input
multi_modal_
kwargs
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
else
:
(
input_tokens
,
input_positions
,
(
input_tokens
,
input_positions
,
...
@@ -307,7 +324,7 @@ class CPUModelRunner:
...
@@ -307,7 +324,7 @@ class CPUModelRunner:
)
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
return
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
multi_modal_
input
)
sampling_metadata
,
multi_modal_
kwargs
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
...
...
vllm/worker/embedding_model_runner.py
View file @
f48954a4
...
@@ -90,7 +90,7 @@ class EmbeddingModelRunner(ModelRunner):
...
@@ -90,7 +90,7 @@ class EmbeddingModelRunner(ModelRunner):
self
,
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
PoolingMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
PoolingMetadata
,
Set
[
LoRARequest
],
LoRAMapping
,
torch
.
Tensor
]:
Set
[
LoRARequest
],
LoRAMapping
,
Dict
[
str
,
torch
.
Tensor
]
]
:
if
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
assert
seq_group_metadata_list
is
not
None
assert
seq_group_metadata_list
is
not
None
# Prepare input tensors.
# Prepare input tensors.
...
@@ -102,7 +102,7 @@ class EmbeddingModelRunner(ModelRunner):
...
@@ -102,7 +102,7 @@ class EmbeddingModelRunner(ModelRunner):
_
,
_
,
lora_mapping
,
lora_mapping
,
lora_requests
,
lora_requests
,
multi_modal_
input
,
multi_modal_
kwargs
,
slot_mapping
,
slot_mapping
,
num_prefill_tokens
,
num_prefill_tokens
,
num_decode_tokens
,
num_decode_tokens
,
...
@@ -117,7 +117,7 @@ class EmbeddingModelRunner(ModelRunner):
...
@@ -117,7 +117,7 @@ class EmbeddingModelRunner(ModelRunner):
"input_positions"
:
input_positions
,
"input_positions"
:
input_positions
,
"lora_requests"
:
lora_requests
,
"lora_requests"
:
lora_requests
,
"lora_mapping"
:
lora_mapping
,
"lora_mapping"
:
lora_mapping
,
"multi_modal_
input
"
:
multi_modal_
input
,
"multi_modal_
kwargs
"
:
multi_modal_
kwargs
,
"num_prefill_tokens"
:
num_prefill_tokens
,
"num_prefill_tokens"
:
num_prefill_tokens
,
"num_decode_tokens"
:
num_decode_tokens
,
"num_decode_tokens"
:
num_decode_tokens
,
"slot_mapping"
:
slot_mapping
,
"slot_mapping"
:
slot_mapping
,
...
@@ -132,7 +132,7 @@ class EmbeddingModelRunner(ModelRunner):
...
@@ -132,7 +132,7 @@ class EmbeddingModelRunner(ModelRunner):
input_positions
=
metadata_dict
.
pop
(
"input_positions"
)
input_positions
=
metadata_dict
.
pop
(
"input_positions"
)
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
multi_modal_
input
=
metadata_dict
.
pop
(
"multi_modal_
input
"
)
multi_modal_
kwargs
=
metadata_dict
.
pop
(
"multi_modal_
kwargs
"
)
if
metadata_dict
:
if
metadata_dict
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
**
metadata_dict
)
...
@@ -143,7 +143,7 @@ class EmbeddingModelRunner(ModelRunner):
...
@@ -143,7 +143,7 @@ class EmbeddingModelRunner(ModelRunner):
prompt_lens
=
None
)
prompt_lens
=
None
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
pooling_metadata
,
return
(
input_tokens
,
input_positions
,
attn_metadata
,
pooling_metadata
,
lora_requests
,
lora_mapping
,
multi_modal_
input
)
lora_requests
,
lora_mapping
,
multi_modal_
kwargs
)
def
_prepare_pooling
(
def
_prepare_pooling
(
self
,
self
,
...
...
vllm/worker/model_runner.py
View file @
f48954a4
import
gc
import
time
import
time
import
warnings
import
warnings
from
collections
import
defaultdict
from
typing
import
Dict
,
List
,
NamedTuple
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
NamedTuple
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -18,9 +20,9 @@ from vllm.lora.request import LoRARequest
...
@@ -18,9 +20,9 @@ from vllm.lora.request import LoRARequest
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
SequenceData
,
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
SequenceGroupMetadata
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
get_kv_cache_torch_dtype
,
is_hip
,
from
vllm.utils
import
(
CudaMemoryProfiler
,
get_kv_cache_torch_dtype
,
is_hip
,
is_pin_memory_available
,
make_tensor_with_pad
)
is_pin_memory_available
,
make_tensor_with_pad
)
...
@@ -34,6 +36,7 @@ _BATCH_SIZE_ALIGNMENT = 8
...
@@ -34,6 +36,7 @@ _BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE
=
[
1
,
2
,
4
]
+
[
_BATCH_SIZES_TO_CAPTURE
=
[
1
,
2
,
4
]
+
[
_BATCH_SIZE_ALIGNMENT
*
i
for
i
in
range
(
1
,
33
)
_BATCH_SIZE_ALIGNMENT
*
i
for
i
in
range
(
1
,
33
)
]
]
_NUM_WARMUP_ITERS
=
2
class
ModelInput
(
NamedTuple
):
class
ModelInput
(
NamedTuple
):
...
@@ -44,7 +47,7 @@ class ModelInput(NamedTuple):
...
@@ -44,7 +47,7 @@ class ModelInput(NamedTuple):
query_lens
:
List
[
int
]
query_lens
:
List
[
int
]
lora_mapping
:
Optional
[
LoRAMapping
]
lora_mapping
:
Optional
[
LoRAMapping
]
lora_requests
:
Set
[
LoRARequest
]
lora_requests
:
Set
[
LoRARequest
]
multi_modal_
input
:
Optional
[
torch
.
Tensor
]
multi_modal_
kwargs
:
Dict
[
str
,
torch
.
Tensor
]
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
num_prefill_tokens
:
int
num_prefill_tokens
:
int
num_decode_tokens
:
int
num_decode_tokens
:
int
...
@@ -60,7 +63,7 @@ class ModelInput(NamedTuple):
...
@@ -60,7 +63,7 @@ class ModelInput(NamedTuple):
query_lens
=
[],
query_lens
=
[],
lora_mapping
=
None
,
lora_mapping
=
None
,
lora_requests
=
set
(),
lora_requests
=
set
(),
multi_modal_
input
=
None
,
multi_modal_
kwargs
=
{}
,
slot_mapping
=
torch
.
empty
(
0
,
device
=
device
),
slot_mapping
=
torch
.
empty
(
0
,
device
=
device
),
num_prefill_tokens
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
0
,
num_decode_tokens
=
0
,
...
@@ -122,6 +125,16 @@ class ModelRunner:
...
@@ -122,6 +125,16 @@ class ModelRunner:
self
.
block_size
,
self
.
block_size
,
)
)
# Create processor for multi-modal data
if
self
.
vision_language_config
is
not
None
:
self
.
multi_modal_input_processor
=
MULTIMODAL_REGISTRY
\
.
create_input_processor
(
self
.
model_config
,
self
.
vision_language_config
,
)
else
:
self
.
multi_modal_input_processor
=
None
# Lazy initialization
# Lazy initialization
self
.
model
:
nn
.
Module
# Set after load_model
self
.
model
:
nn
.
Module
# Set after load_model
# Set if the backend is flashinfer.
# Set if the backend is flashinfer.
...
@@ -242,7 +255,8 @@ class ModelRunner:
...
@@ -242,7 +255,8 @@ class ModelRunner:
context_lens
:
List
[
int
]
=
[]
context_lens
:
List
[
int
]
=
[]
query_lens
:
List
[
int
]
=
[]
query_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
multi_modal_kwargs_list
:
Dict
[
str
,
List
[
torch
.
Tensor
]]
=
defaultdict
(
list
)
decode_only
=
True
decode_only
=
True
num_prefills
=
0
num_prefills
=
0
num_prefill_tokens
=
0
num_prefill_tokens
=
0
...
@@ -415,11 +429,19 @@ class ModelRunner:
...
@@ -415,11 +429,19 @@ class ModelRunner:
[
lora_id
]
*
[
lora_id
]
*
(
query_len
if
seq_group_metadata
.
sampling_params
(
query_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
is
not
None
else
1
))
if
seq_group_metadata
.
multi_modal_data
:
mm_data
=
seq_group_metadata
.
multi_modal_data
multi_modal_input_list
.
append
(
if
mm_data
is
not
None
:
seq_group_metadata
.
multi_modal_data
.
data
)
# Process multi-modal data
if
self
.
multi_modal_input_processor
is
None
:
raise
ValueError
(
"Multi-modal inputs are only supported by "
"vision language models."
)
mm_kwargs
=
self
.
multi_modal_input_processor
(
mm_data
)
for
k
,
v
in
mm_kwargs
.
items
():
multi_modal_kwargs_list
[
k
].
append
(
v
)
if
_is_block_tables_empty
(
seq_group_metadata
.
block_tables
):
if
_is_block_tables_empty
(
seq_group_metadata
.
block_tables
):
# During memory profiling, the block tables are not
# During memory profiling, the block tables are not
...
@@ -505,26 +527,6 @@ class ModelRunner:
...
@@ -505,26 +527,6 @@ class ModelRunner:
)
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
if
multi_modal_input_list
:
assert
self
.
vision_language_config
,
(
"Multi-modal inputs are only supported by "
"vision language models."
)
multi_modal_input
=
torch
.
cat
(
multi_modal_input_list
,
dim
=
0
).
to
(
self
.
device
)
else
:
multi_modal_input
=
None
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
device
=
self
.
device
)
...
@@ -532,11 +534,6 @@ class ModelRunner:
...
@@ -532,11 +534,6 @@ class ModelRunner:
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
self
.
device
)
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
torch
.
cumsum
(
seq_lens_tensor
,
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
dtype
=
seq_start_loc
.
dtype
,
...
@@ -589,6 +586,21 @@ class ModelRunner:
...
@@ -589,6 +586,21 @@ class ModelRunner:
seq_start_loc
=
seq_start_loc
,
seq_start_loc
=
seq_start_loc
,
data_type
=
kv_cache_dtype
)
data_type
=
kv_cache_dtype
)
else
:
else
:
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
num_prefills
=
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
slot_mapping
=
slot_mapping_tensor
,
...
@@ -614,6 +626,11 @@ class ModelRunner:
...
@@ -614,6 +626,11 @@ class ModelRunner:
else
:
else
:
lora_mapping
=
None
lora_mapping
=
None
multi_modal_kwargs
=
{
k
:
torch
.
cat
(
v
,
dim
=
0
).
to
(
self
.
device
)
for
k
,
v
in
multi_modal_kwargs_list
.
items
()
}
return
ModelInput
(
return
ModelInput
(
input_tokens
=
input_tokens_tensor
,
input_tokens
=
input_tokens_tensor
,
input_positions
=
input_positions_tensor
,
input_positions
=
input_positions_tensor
,
...
@@ -622,7 +639,7 @@ class ModelRunner:
...
@@ -622,7 +639,7 @@ class ModelRunner:
query_lens
=
query_lens
,
query_lens
=
query_lens
,
lora_mapping
=
lora_mapping
,
lora_mapping
=
lora_mapping
,
lora_requests
=
lora_requests
,
lora_requests
=
lora_requests
,
multi_modal_
input
=
multi_modal_
input
,
multi_modal_
kwargs
=
multi_modal_
kwargs
,
slot_mapping
=
slot_mapping_tensor
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
num_prefill_tokens
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
num_decode_tokens
=
num_decode_tokens
,
...
@@ -633,7 +650,7 @@ class ModelRunner:
...
@@ -633,7 +650,7 @@ class ModelRunner:
self
,
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
SamplingMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
SamplingMetadata
,
Set
[
LoRARequest
],
LoRAMapping
,
torch
.
Tensor
]:
Set
[
LoRARequest
],
LoRAMapping
,
Dict
[
str
,
torch
.
Tensor
]
]
:
if
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
assert
seq_group_metadata_list
is
not
None
assert
seq_group_metadata_list
is
not
None
# Prepare input tensors.
# Prepare input tensors.
...
@@ -645,7 +662,7 @@ class ModelRunner:
...
@@ -645,7 +662,7 @@ class ModelRunner:
query_lens
,
query_lens
,
lora_mapping
,
lora_mapping
,
lora_requests
,
lora_requests
,
multi_modal_
input
,
multi_modal_
kwargs
,
slot_mapping
,
slot_mapping
,
num_prefill_tokens
,
num_prefill_tokens
,
num_decode_tokens
,
num_decode_tokens
,
...
@@ -662,7 +679,7 @@ class ModelRunner:
...
@@ -662,7 +679,7 @@ class ModelRunner:
sampling_metadata
.
selected_token_indices
,
sampling_metadata
.
selected_token_indices
,
"lora_requests"
:
lora_requests
,
"lora_requests"
:
lora_requests
,
"lora_mapping"
:
lora_mapping
,
"lora_mapping"
:
lora_mapping
,
"multi_modal_
input
"
:
multi_modal_
input
,
"multi_modal_
kwargs
"
:
multi_modal_
kwargs
,
"num_prefill_tokens"
:
num_prefill_tokens
,
"num_prefill_tokens"
:
num_prefill_tokens
,
"num_decode_tokens"
:
num_decode_tokens
,
"num_decode_tokens"
:
num_decode_tokens
,
"slot_mapping"
:
slot_mapping
,
"slot_mapping"
:
slot_mapping
,
...
@@ -679,7 +696,7 @@ class ModelRunner:
...
@@ -679,7 +696,7 @@ class ModelRunner:
"selected_token_indices"
)
"selected_token_indices"
)
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
multi_modal_
input
=
metadata_dict
.
pop
(
"multi_modal_
input
"
)
multi_modal_
kwargs
=
metadata_dict
.
pop
(
"multi_modal_
kwargs
"
)
if
metadata_dict
:
if
metadata_dict
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
**
metadata_dict
)
...
@@ -694,7 +711,7 @@ class ModelRunner:
...
@@ -694,7 +711,7 @@ class ModelRunner:
return
(
input_tokens
,
input_positions
,
attn_metadata
,
return
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
lora_requests
,
lora_mapping
,
sampling_metadata
,
lora_requests
,
lora_mapping
,
multi_modal_
input
)
multi_modal_
kwargs
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
...
@@ -703,7 +720,7 @@ class ModelRunner:
...
@@ -703,7 +720,7 @@ class ModelRunner:
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
lora_requests
,
lora_mapping
,
multi_modal_
input
lora_requests
,
lora_mapping
,
multi_modal_
kwargs
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
if
self
.
lora_config
:
if
self
.
lora_config
:
...
@@ -717,15 +734,14 @@ class ModelRunner:
...
@@ -717,15 +734,14 @@ class ModelRunner:
model_executable
=
self
.
graph_runners
[
graph_batch_size
]
model_executable
=
self
.
graph_runners
[
graph_batch_size
]
else
:
else
:
model_executable
=
self
.
model
model_executable
=
self
.
model
execute_model_kwargs
=
{
"input_ids"
:
input_tokens
,
hidden_states
=
model_executable
(
"positions"
:
input_positions
,
input_ids
=
input_tokens
,
"kv_caches"
:
kv_caches
,
positions
=
input_positions
,
"attn_metadata"
:
attn_metadata
,
kv_caches
=
kv_caches
,
}
attn_metadata
=
attn_metadata
,
if
self
.
vision_language_config
:
**
multi_modal_kwargs
,
execute_model_kwargs
.
update
({
"image_input"
:
multi_modal_input
})
)
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
# Compute the logits.
# Compute the logits.
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
...
@@ -781,16 +797,24 @@ class ModelRunner:
...
@@ -781,16 +797,24 @@ class ModelRunner:
# To exercise the worst scenario for GPU memory consumption,
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
# of images processed.
if
self
.
vision_language_config
:
model_config
=
self
.
model_config
vlm_config
=
self
.
vision_language_config
if
vlm_config
:
max_num_seqs
=
min
(
max_num_seqs
=
min
(
max_num_seqs
,
max_num_seqs
,
int
(
max_num_batched_tokens
/
int
(
max_num_batched_tokens
/
vlm_config
.
image_feature_size
))
self
.
vision_language_config
.
image_feature_size
))
for
group_id
in
range
(
max_num_seqs
):
for
group_id
in
range
(
max_num_seqs
):
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
seq_data
,
fake_multi_modal_input
=
_prepare_fake_inputs
(
seq_len
,
self
.
vision_language_config
)
if
vlm_config
is
None
:
seq_data
=
SequenceData
([
0
]
*
seq_len
)
dummy_multi_modal_data
=
None
else
:
seq_data
,
dummy_multi_modal_data
=
MULTIMODAL_REGISTRY
\
.
dummy_data_for_profiling
(
seq_len
,
model_config
,
vlm_config
)
seq
=
SequenceGroupMetadata
(
seq
=
SequenceGroupMetadata
(
request_id
=
str
(
group_id
),
request_id
=
str
(
group_id
),
is_prompt
=
True
,
is_prompt
=
True
,
...
@@ -799,7 +823,7 @@ class ModelRunner:
...
@@ -799,7 +823,7 @@ class ModelRunner:
block_tables
=
None
,
block_tables
=
None
,
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
if
dummy_lora_requests_per_seq
else
None
,
if
dummy_lora_requests_per_seq
else
None
,
multi_modal_data
=
fake
_multi_modal_
input
,
multi_modal_data
=
dummy
_multi_modal_
data
,
)
)
seqs
.
append
(
seq
)
seqs
.
append
(
seq
)
...
@@ -871,6 +895,10 @@ class ModelRunner:
...
@@ -871,6 +895,10 @@ class ModelRunner:
seq_lens
=
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
).
cuda
()
seq_lens
=
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
).
cuda
()
block_tables
=
torch
.
from_numpy
(
self
.
graph_block_tables
).
cuda
()
block_tables
=
torch
.
from_numpy
(
self
.
graph_block_tables
).
cuda
()
# Prepare buffer for outputs. These will be reused for all batch sizes.
# It will be filled after the first graph capture.
hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
graph_batch_size
=
_get_graph_batch_size
(
graph_batch_size
=
_get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
self
.
scheduler_config
.
max_num_seqs
)
batch_size_capture_list
=
[
batch_size_capture_list
=
[
...
@@ -907,9 +935,11 @@ class ModelRunner:
...
@@ -907,9 +935,11 @@ class ModelRunner:
self
.
set_active_loras
(
set
(),
lora_mapping
)
self
.
set_active_loras
(
set
(),
lora_mapping
)
graph_runner
=
CUDAGraphRunner
(
self
.
model
)
graph_runner
=
CUDAGraphRunner
(
self
.
model
)
graph_runner
.
capture
(
hidden_states
=
graph_runner
.
capture
(
input_tokens
[:
batch_size
],
input_tokens
[:
batch_size
],
input_positions
[:
batch_size
],
input_positions
[:
batch_size
],
hidden_states
[:
batch_size
]
if
hidden_states
is
not
None
else
None
,
kv_caches
,
kv_caches
,
attn_metadata
,
attn_metadata
,
memory_pool
=
self
.
graph_memory_pool
,
memory_pool
=
self
.
graph_memory_pool
,
...
@@ -946,35 +976,46 @@ class CUDAGraphRunner:
...
@@ -946,35 +976,46 @@ class CUDAGraphRunner:
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
Optional
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
memory_pool
:
Optional
[
Tuple
[
int
,
int
]],
memory_pool
:
Optional
[
Tuple
[
int
,
int
]],
stream
:
torch
.
cuda
.
Stream
,
stream
:
torch
.
cuda
.
Stream
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
torch
.
Tensor
:
assert
self
.
_graph
is
None
assert
self
.
_graph
is
None
# Run the model
once
without capturing the graph.
# Run the model
a few times
without capturing the graph.
# This is to make sure that the captured graph does not include the
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
# kernel launches for initial benchmarking (e.g., Triton autotune).
self
.
model
(
# Note one iteration is not enough for torch.jit.script
input_ids
,
for
_
in
range
(
_NUM_WARMUP_ITERS
):
positions
,
self
.
model
(
kv_caches
,
input_ids
,
attn_metadata
,
positions
,
**
kwargs
,
kv_caches
,
)
attn_metadata
,
**
kwargs
,
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
# Capture the graph.
# Capture the graph.
self
.
_graph
=
torch
.
cuda
.
CUDAGraph
()
self
.
_graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
self
.
_graph
,
pool
=
memory_pool
,
stream
=
stream
):
with
torch
.
cuda
.
graph
(
self
.
_graph
,
pool
=
memory_pool
,
stream
=
stream
):
hidden_states
=
self
.
model
(
output_
hidden_states
=
self
.
model
(
input_ids
,
input_ids
,
positions
,
positions
,
kv_caches
,
kv_caches
,
attn_metadata
,
attn_metadata
,
**
kwargs
,
**
kwargs
,
)
)
if
hidden_states
is
not
None
:
hidden_states
.
copy_
(
output_hidden_states
)
else
:
hidden_states
=
output_hidden_states
del
output_hidden_states
# make sure `output_hidden_states` is deleted
# in the graph's memory pool
gc
.
collect
()
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
# Save the input and output buffers.
# Save the input and output buffers.
...
@@ -987,7 +1028,7 @@ class CUDAGraphRunner:
...
@@ -987,7 +1028,7 @@ class CUDAGraphRunner:
"block_tables"
:
attn_metadata
.
decode_metadata
.
block_tables
,
"block_tables"
:
attn_metadata
.
decode_metadata
.
block_tables
,
}
}
self
.
output_buffers
=
{
"hidden_states"
:
hidden_states
}
self
.
output_buffers
=
{
"hidden_states"
:
hidden_states
}
return
return
hidden_states
def
forward
(
def
forward
(
self
,
self
,
...
@@ -1034,24 +1075,6 @@ def _get_graph_batch_size(batch_size: int) -> int:
...
@@ -1034,24 +1075,6 @@ def _get_graph_batch_size(batch_size: int) -> int:
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
def
_prepare_fake_inputs
(
seq_len
:
int
,
vision_language_config
:
Optional
[
VisionLanguageConfig
]):
"""Prepare fake inputs for profile run."""
if
vision_language_config
:
prompt_tokens
=
[
vision_language_config
.
image_token_id
]
*
vision_language_config
.
image_feature_size
+
[
0
]
*
(
seq_len
-
vision_language_config
.
image_feature_size
)
fake_image_input
=
MultiModalData
(
type
=
MultiModalData
.
Type
.
IMAGE
,
data
=
torch
.
zeros
(
vision_language_config
.
image_input_shape
,
dtype
=
torch
.
float16
))
else
:
prompt_tokens
=
[
0
]
*
seq_len
fake_image_input
=
None
return
SequenceData
(
prompt_tokens
),
fake_image_input
def
_is_block_tables_empty
(
block_tables
:
Union
[
None
,
Dict
]):
def
_is_block_tables_empty
(
block_tables
:
Union
[
None
,
Dict
]):
"""
"""
Check if block_tables is None or a dictionary with all None values.
Check if block_tables is None or a dictionary with all None values.
...
...
Prev
1
…
9
10
11
12
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment