Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9b902f9e
Commit
9b902f9e
authored
Sep 11, 2024
by
zhuwenwen
Browse files
fix run error
parent
a48d654d
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
271 additions
and
142 deletions
+271
-142
vllm/attention/selector.py
vllm/attention/selector.py
+6
-6
vllm/benchmark_throughput.py
vllm/benchmark_throughput.py
+132
-7
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+129
-121
vllm/executor/multiproc_worker_utils.py
vllm/executor/multiproc_worker_utils.py
+1
-2
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+3
-6
No files found.
vllm/attention/selector.py
View file @
9b902f9e
...
@@ -202,12 +202,12 @@ def which_attn_to_use(
...
@@ -202,12 +202,12 @@ def which_attn_to_use(
# AMD GPUs.
# AMD GPUs.
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
if
selected_backend
==
_Backend
.
ROCM_FLASH
:
#
if selected_backend == _Backend.ROCM_FLASH:
if
current_platform
.
get_device_capability
()[
0
]
!=
9
:
#
if current_platform.get_device_capability()[0] != 9:
# not Instinct series GPUs.
#
# not Instinct series GPUs.
logger
.
info
(
"flash_attn is not supported on NAVI GPUs."
)
#
logger.info("flash_attn is not supported on NAVI GPUs.")
else
:
#
else:
logger
.
info
(
"%s is not supported in AMD GPUs."
,
selected_backend
)
#
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return
_Backend
.
ROCM_FLASH
return
_Backend
.
ROCM_FLASH
# FlashAttn in NVIDIA GPUs.
# FlashAttn in NVIDIA GPUs.
...
...
vllm/benchmark_throughput.py
View file @
9b902f9e
...
@@ -7,14 +7,16 @@ from typing import List, Optional, Tuple
...
@@ -7,14 +7,16 @@ from typing import List, Optional, Tuple
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
uvloop
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
transformers
import
(
AutoModelForCausalLM
,
AutoTokenizer
,
from
transformers
import
(
AutoModelForCausalLM
,
AutoTokenizer
,
PreTrainedTokenizerBase
)
PreTrainedTokenizerBase
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.inputs
import
PromptInputs
from
vllm.inputs
import
PromptInputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
from
vllm.entrypoints.openai.api_server
import
(
build_async_engine_client_from_engine_args
)
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
,
merge_async_iterators
def
sample_requests
(
def
sample_requests
(
...
@@ -85,8 +87,11 @@ def run_vllm(
...
@@ -85,8 +87,11 @@ def run_vllm(
max_num_batched_tokens
:
int
,
max_num_batched_tokens
:
int
,
distributed_executor_backend
:
Optional
[
str
],
distributed_executor_backend
:
Optional
[
str
],
gpu_memory_utilization
:
float
=
0.9
,
gpu_memory_utilization
:
float
=
0.9
,
num_scheduler_steps
:
int
=
1
,
use_v2_block_manager
:
bool
=
False
,
download_dir
:
Optional
[
str
]
=
None
,
download_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
EngineArgs
.
load_format
,
load_format
:
str
=
EngineArgs
.
load_format
,
disable_async_output_proc
:
bool
=
False
,
)
->
float
:
)
->
float
:
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
llm
=
LLM
(
llm
=
LLM
(
...
@@ -109,6 +114,9 @@ def run_vllm(
...
@@ -109,6 +114,9 @@ def run_vllm(
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_batched_tokens
=
max_num_batched_tokens
,
distributed_executor_backend
=
distributed_executor_backend
,
distributed_executor_backend
=
distributed_executor_backend
,
load_format
=
load_format
,
load_format
=
load_format
,
num_scheduler_steps
=
num_scheduler_steps
,
use_v2_block_manager
=
use_v2_block_manager
,
disable_async_output_proc
=
disable_async_output_proc
,
)
)
# Add the requests to the engine.
# Add the requests to the engine.
...
@@ -167,6 +175,93 @@ def run_vllm(
...
@@ -167,6 +175,93 @@ def run_vllm(
return
end
-
start
return
end
-
start
async
def
run_vllm_async
(
requests
:
List
[
Tuple
[
str
,
int
,
int
]],
model
:
str
,
tokenizer
:
str
,
quantization
:
Optional
[
str
],
tensor_parallel_size
:
int
,
seed
:
int
,
n
:
int
,
use_beam_search
:
bool
,
trust_remote_code
:
bool
,
dtype
:
str
,
max_model_len
:
Optional
[
int
],
enforce_eager
:
bool
,
kv_cache_dtype
:
str
,
quantization_param_path
:
Optional
[
str
],
device
:
str
,
enable_prefix_caching
:
bool
,
enable_chunked_prefill
:
bool
,
max_num_batched_tokens
:
int
,
distributed_executor_backend
:
Optional
[
str
],
gpu_memory_utilization
:
float
=
0.9
,
num_scheduler_steps
:
int
=
1
,
use_v2_block_manager
:
bool
=
False
,
download_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
EngineArgs
.
load_format
,
disable_async_output_proc
:
bool
=
False
,
disable_frontend_multiprocessing
:
bool
=
False
,
)
->
float
:
from
vllm
import
SamplingParams
engine_args
=
AsyncEngineArgs
(
model
=
model
,
tokenizer
=
tokenizer
,
quantization
=
quantization
,
tensor_parallel_size
=
tensor_parallel_size
,
seed
=
seed
,
trust_remote_code
=
trust_remote_code
,
dtype
=
dtype
,
max_model_len
=
max_model_len
,
gpu_memory_utilization
=
gpu_memory_utilization
,
enforce_eager
=
enforce_eager
,
kv_cache_dtype
=
kv_cache_dtype
,
quantization_param_path
=
quantization_param_path
,
device
=
device
,
enable_prefix_caching
=
enable_prefix_caching
,
download_dir
=
download_dir
,
enable_chunked_prefill
=
enable_chunked_prefill
,
max_num_batched_tokens
=
max_num_batched_tokens
,
distributed_executor_backend
=
distributed_executor_backend
,
load_format
=
load_format
,
num_scheduler_steps
=
num_scheduler_steps
,
use_v2_block_manager
=
use_v2_block_manager
,
disable_async_output_proc
=
disable_async_output_proc
,
worker_use_ray
=
False
,
engine_use_ray
=
False
,
disable_log_requests
=
True
,
)
async
with
build_async_engine_client_from_engine_args
(
engine_args
,
disable_frontend_multiprocessing
)
as
llm
:
# Add the requests to the engine.
prompts
:
List
[
str
]
=
[]
sampling_params
:
List
[
SamplingParams
]
=
[]
for
prompt
,
_
,
output_len
in
requests
:
prompts
.
append
(
prompt
)
sampling_params
.
append
(
SamplingParams
(
n
=
n
,
temperature
=
0.0
if
use_beam_search
else
1.0
,
top_p
=
1.0
,
use_beam_search
=
use_beam_search
,
ignore_eos
=
True
,
max_tokens
=
output_len
,
))
generators
=
[]
start
=
time
.
perf_counter
()
for
i
,
(
prompt
,
sp
)
in
enumerate
(
zip
(
prompts
,
sampling_params
)):
generator
=
llm
.
generate
(
prompt
,
sp
,
request_id
=
f
"test
{
i
}
"
)
generators
.
append
(
generator
)
all_gens
=
merge_async_iterators
(
*
generators
)
async
for
i
,
res
in
all_gens
:
pass
end
=
time
.
perf_counter
()
return
end
-
start
def
run_hf
(
def
run_hf
(
requests
:
List
[
Tuple
[
str
,
int
,
int
]],
requests
:
List
[
Tuple
[
str
,
int
,
int
]],
model
:
str
,
model
:
str
,
...
@@ -266,15 +361,24 @@ def main(args: argparse.Namespace):
...
@@ -266,15 +361,24 @@ def main(args: argparse.Namespace):
args
.
output_len
)
args
.
output_len
)
if
args
.
backend
==
"vllm"
:
if
args
.
backend
==
"vllm"
:
elapsed_time
=
run_vllm
(
run_args
=
[
warmup_requests
,
requests
,
args
.
model
,
args
.
tokenizer
,
args
.
quantization
,
requests
,
args
.
model
,
args
.
tokenizer
,
args
.
quantization
,
args
.
tensor_parallel_size
,
args
.
seed
,
args
.
n
,
args
.
use_beam_search
,
args
.
tensor_parallel_size
,
args
.
seed
,
args
.
n
,
args
.
use_beam_search
,
args
.
trust_remote_code
,
args
.
dtype
,
args
.
max_model_len
,
args
.
trust_remote_code
,
args
.
dtype
,
args
.
max_model_len
,
args
.
enforce_eager
,
args
.
kv_cache_dtype
,
args
.
enforce_eager
,
args
.
kv_cache_dtype
,
args
.
quantization_param_path
,
args
.
device
,
args
.
quantization_param_path
,
args
.
device
,
args
.
enable_prefix_caching
,
args
.
enable_chunked_prefill
,
args
.
enable_prefix_caching
,
args
.
enable_chunked_prefill
,
args
.
max_num_batched_tokens
,
args
.
distributed_executor_backend
,
args
.
max_num_batched_tokens
,
args
.
distributed_executor_backend
,
args
.
gpu_memory_utilization
,
args
.
download_dir
,
args
.
load_format
)
args
.
gpu_memory_utilization
,
args
.
num_scheduler_steps
,
args
.
use_v2_block_manager
,
args
.
download_dir
,
args
.
load_format
,
args
.
disable_async_output_proc
]
if
args
.
async_engine
:
run_args
.
append
(
args
.
disable_frontend_multiprocessing
)
elapsed_time
=
uvloop
.
run
(
run_vllm_async
(
*
run_args
))
else
:
elapsed_time
=
run_vllm
(
*
run_args
)
elif
args
.
backend
==
"hf"
:
elif
args
.
backend
==
"hf"
:
assert
args
.
tensor_parallel_size
==
1
assert
args
.
tensor_parallel_size
==
1
elapsed_time
=
run_hf
(
requests
,
args
.
model
,
tokenizer
,
args
.
n
,
elapsed_time
=
run_hf
(
requests
,
args
.
model
,
tokenizer
,
args
.
n
,
...
@@ -407,10 +511,18 @@ if __name__ == "__main__":
...
@@ -407,10 +511,18 @@ if __name__ == "__main__":
choices
=
[
"auto"
,
"cuda"
,
"cpu"
,
"openvino"
,
"tpu"
,
"xpu"
],
choices
=
[
"auto"
,
"cuda"
,
"cpu"
,
"openvino"
,
"tpu"
,
"xpu"
],
help
=
'device type for vLLM execution, supporting CUDA, OpenVINO and '
help
=
'device type for vLLM execution, supporting CUDA, OpenVINO and '
'CPU.'
)
'CPU.'
)
parser
.
add_argument
(
"--num-scheduler-steps"
,
type
=
int
,
default
=
1
,
help
=
"Maximum number of forward steps per scheduler call."
)
parser
.
add_argument
(
"--use-v2-block-manager"
,
action
=
'store_true'
,
help
=
"Enable block manager v2."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--enable-prefix-caching"
,
"--enable-prefix-caching"
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
"
e
nable automatic prefix caching for vLLM backend."
)
help
=
"
E
nable automatic prefix caching for vLLM backend."
)
parser
.
add_argument
(
"--enable-chunked-prefill"
,
parser
.
add_argument
(
"--enable-chunked-prefill"
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
"enable chunked prefill for vLLM backend."
)
help
=
"enable chunked prefill for vLLM backend."
)
...
@@ -459,6 +571,19 @@ if __name__ == "__main__":
...
@@ -459,6 +571,19 @@ if __name__ == "__main__":
'section for more information.
\n
'
'section for more information.
\n
'
'* "bitsandbytes" will load the weights using bitsandbytes '
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.
\n
'
)
'quantization.
\n
'
)
parser
.
add_argument
(
"--disable-async-output-proc"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Disable async output processor for vLLM backend."
)
parser
.
add_argument
(
"--async-engine"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Use vLLM async engine rather than LLM class."
)
parser
.
add_argument
(
"--disable-frontend-multiprocessing"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Disable decoupled async engine frontend."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
args
.
tokenizer
is
None
:
if
args
.
tokenizer
is
None
:
args
.
tokenizer
=
args
.
model
args
.
tokenizer
=
args
.
model
...
...
vllm/engine/llm_engine.py
View file @
9b902f9e
...
@@ -315,128 +315,136 @@ class LLMEngine:
...
@@ -315,128 +315,136 @@ class LLMEngine:
observability_config
=
self
.
observability_config
,
observability_config
=
self
.
observability_config
,
)
)
if
not
self
.
model_config
.
embedding_mode
:
init_success
=
False
self
.
_initialize_kv_caches
()
try
:
if
not
self
.
model_config
.
embedding_mode
:
# If usage stat is enabled, collect relevant info.
self
.
_initialize_kv_caches
()
if
is_usage_stats_enabled
():
from
vllm.model_executor.model_loader
import
(
# If usage stat is enabled, collect relevant info.
get_architecture_class_name
)
if
is_usage_stats_enabled
():
usage_message
.
report_usage
(
from
vllm.model_executor.model_loader
import
(
get_architecture_class_name
(
model_config
),
get_architecture_class_name
)
usage_context
,
usage_message
.
report_usage
(
extra_kvs
=
{
get_architecture_class_name
(
model_config
),
# Common configuration
usage_context
,
"dtype"
:
extra_kvs
=
{
str
(
model_config
.
dtype
),
# Common configuration
"tensor_parallel_size"
:
"dtype"
:
parallel_config
.
tensor_parallel_size
,
str
(
model_config
.
dtype
),
"block_size"
:
"tensor_parallel_size"
:
cache_config
.
block_size
,
parallel_config
.
tensor_parallel_size
,
"gpu_memory_utilization"
:
"block_size"
:
cache_config
.
gpu_memory_utilization
,
cache_config
.
block_size
,
"gpu_memory_utilization"
:
# Quantization
cache_config
.
gpu_memory_utilization
,
"quantization"
:
model_config
.
quantization
,
# Quantization
"kv_cache_dtype"
:
"quantization"
:
str
(
cache_config
.
cache_dtype
),
model_config
.
quantization
,
"kv_cache_dtype"
:
# Feature flags
str
(
cache_config
.
cache_dtype
),
"enable_lora"
:
bool
(
lora_config
),
# Feature flags
"enable_prompt_adapter"
:
"enable_lora"
:
bool
(
prompt_adapter_config
),
bool
(
lora_config
),
"enable_prefix_caching"
:
"enable_prompt_adapter"
:
cache_config
.
enable_prefix_caching
,
bool
(
prompt_adapter_config
),
"enforce_eager"
:
"enable_prefix_caching"
:
model_config
.
enforce_eager
,
cache_config
.
enable_prefix_caching
,
"disable_custom_all_reduce"
:
"enforce_eager"
:
parallel_config
.
disable_custom_all_reduce
,
model_config
.
enforce_eager
,
})
"disable_custom_all_reduce"
:
parallel_config
.
disable_custom_all_reduce
,
if
self
.
tokenizer
:
})
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
if
self
.
tokenizer
:
self
.
tokenizer
.
ping
()
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self
.
cached_scheduler_outputs
=
[
self
.
tokenizer
.
ping
()
SchedulerOutputState
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
self
.
cached_scheduler_outputs
=
[
]
SchedulerOutputState
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
self
.
scheduler_contexts
=
[
]
SchedulerContext
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
self
.
scheduler_contexts
=
[
]
SchedulerContext
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
self
.
async_callbacks
=
[
]
functools
.
partial
(
self
.
_process_model_outputs
,
ctx
=
self
.
scheduler_contexts
[
v_id
])
self
.
async_callbacks
=
[
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
functools
.
partial
(
self
.
_process_model_outputs
,
]
ctx
=
self
.
scheduler_contexts
[
v_id
])
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
# Currently used by AsyncLLMEngine to ensure quick append
]
# of request outputs to asyncio queues
self
.
process_request_outputs_callback
=
None
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
# Create the scheduler.
self
.
process_request_outputs_callback
=
None
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
# Create the scheduler.
self
.
scheduler
=
[
# NOTE: the cache_config here have been updated with the numbers of
Scheduler
(
# GPU and CPU blocks, which are profiled in the distributed executor.
scheduler_config
,
cache_config
,
lora_config
,
self
.
scheduler
=
[
parallel_config
.
pipeline_parallel_size
,
Scheduler
(
self
.
async_callbacks
[
v_id
]
scheduler_config
,
cache_config
,
lora_config
,
if
model_config
.
use_async_output_proc
else
None
)
parallel_config
.
pipeline_parallel_size
,
for
v_id
in
range
(
parallel_config
.
pipeline_parallel_size
)
self
.
async_callbacks
[
v_id
]
]
if
model_config
.
use_async_output_proc
else
None
)
for
v_id
in
range
(
parallel_config
.
pipeline_parallel_size
)
# Metric Logging.
]
if
self
.
log_stats
:
if
stat_loggers
is
not
None
:
# Metric Logging.
self
.
stat_loggers
=
stat_loggers
if
self
.
log_stats
:
else
:
if
stat_loggers
is
not
None
:
# Lazy import for prometheus multiprocessing.
self
.
stat_loggers
=
stat_loggers
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
else
:
# before prometheus_client is imported.
# Lazy import for prometheus multiprocessing.
# See https://prometheus.github.io/client_python/multiprocess/
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
from
vllm.engine.metrics
import
(
LoggingStatLogger
,
# before prometheus_client is imported.
PrometheusStatLogger
)
# See https://prometheus.github.io/client_python/multiprocess/
from
vllm.engine.metrics
import
(
LoggingStatLogger
,
self
.
stat_loggers
=
{
PrometheusStatLogger
)
"logging"
:
LoggingStatLogger
(
self
.
stat_loggers
=
{
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
),
"logging"
:
"prometheus"
:
LoggingStatLogger
(
PrometheusStatLogger
(
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
),
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
,
"prometheus"
:
labels
=
dict
(
model_name
=
model_config
.
served_model_name
),
PrometheusStatLogger
(
max_model_len
=
self
.
model_config
.
max_model_len
),
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
,
}
labels
=
dict
(
model_name
=
model_config
.
served_model_name
),
self
.
stat_loggers
[
"prometheus"
].
info
(
"cache_config"
,
max_model_len
=
self
.
model_config
.
max_model_len
),
self
.
cache_config
)
}
self
.
stat_loggers
[
"prometheus"
].
info
(
"cache_config"
,
self
.
tracer
=
None
self
.
cache_config
)
if
self
.
observability_config
.
otlp_traces_endpoint
:
self
.
tracer
=
init_tracer
(
self
.
tracer
=
None
"vllm.llm_engine"
,
if
self
.
observability_config
.
otlp_traces_endpoint
:
self
.
observability_config
.
otlp_traces_endpoint
)
self
.
tracer
=
init_tracer
(
"vllm.llm_engine"
,
# Create sequence output processor, e.g. for beam search or
self
.
observability_config
.
otlp_traces_endpoint
)
# speculative decoding.
self
.
output_processor
=
(
# Create sequence output processor, e.g. for beam search or
SequenceGroupOutputProcessor
.
create_output_processor
(
# speculative decoding.
self
.
scheduler_config
,
self
.
output_processor
=
(
self
.
detokenizer
,
SequenceGroupOutputProcessor
.
create_output_processor
(
self
.
scheduler
,
self
.
scheduler_config
,
self
.
seq_counter
,
self
.
detokenizer
,
get_tokenizer_for_seq
,
self
.
scheduler
,
stop_checker
=
StopChecker
(
self
.
seq_counter
,
self
.
scheduler_config
.
max_model_len
,
get_tokenizer_for_seq
,
get_tokenizer_for_seq
,
),
stop_checker
=
StopChecker
(
))
self
.
scheduler_config
.
max_model_len
,
get_tokenizer_for_seq
,
),
))
init_success
=
True
finally
:
if
not
init_success
:
# Ensure that model_executor is shut down if LLMEngine init
# failed
self
.
model_executor
.
shutdown
()
def
_initialize_kv_caches
(
self
)
->
None
:
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
"""Initialize the KV cache in the worker(s).
...
...
vllm/executor/multiproc_worker_utils.py
View file @
9b902f9e
...
@@ -131,8 +131,7 @@ class WorkerMonitor(threading.Thread):
...
@@ -131,8 +131,7 @@ class WorkerMonitor(threading.Thread):
if
process
.
exitcode
is
not
None
and
process
.
exitcode
!=
0
:
if
process
.
exitcode
is
not
None
and
process
.
exitcode
!=
0
:
died_count
+=
1
died_count
+=
1
logger
.
error
(
"Worker %s pid %s died, exit code: %s"
,
logger
.
error
(
"Worker %s pid %s died, exit code: %s"
,
process
.
name
,
process
.
pid
,
process
.
name
,
process
.
pid
,
process
.
exitcode
)
process
.
exitcode
)
if
died_count
<
len
(
self
.
workers
):
if
died_count
<
len
(
self
.
workers
):
logger
.
info
(
logger
.
info
(
"Killing remaining local vLLM worker processes"
)
"Killing remaining local vLLM worker processes"
)
...
...
vllm/worker/model_runner.py
View file @
9b902f9e
...
@@ -1106,12 +1106,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1106,12 +1106,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
batch_size
+=
seq_len
batch_size
+=
seq_len
seq_data
,
dummy_multi_modal_data
=
INPUT_REGISTRY
\
seq_data
,
dummy_multi_modal_data
=
INPUT_REGISTRY
\
.
dummy_data_for_profiling
(
model_config
,
seq_len
)
.
dummy_data_for_profiling
(
self
.
model_config
,
seq_len
,
# Having more tokens is over-conservative but otherwise fine
self
.
mm_registry
)
assert
len
(
seq_data
.
prompt_token_ids
)
>=
seq_len
,
(
f
"Expected at least
{
seq_len
}
dummy tokens for profiling, "
f
"but got:
{
len
(
seq_data
.
prompt_token_ids
)
}
"
)
seq
=
SequenceGroupMetadata
(
seq
=
SequenceGroupMetadata
(
request_id
=
str
(
group_id
),
request_id
=
str
(
group_id
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment