Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
df5dafaa
Unverified
Commit
df5dafaa
authored
Jan 25, 2025
by
Cyrus Leung
Committed by
GitHub
Jan 24, 2025
Browse files
[Misc] Remove deprecated code (#12383)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
ab5bbf5a
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
25 additions
and
78 deletions
+25
-78
tests/async_engine/test_api_server.py
tests/async_engine/test_api_server.py
+14
-9
tests/basic_correctness/test_preemption.py
tests/basic_correctness/test_preemption.py
+9
-9
tests/multi_step/test_correctness_async_llm.py
tests/multi_step/test_correctness_async_llm.py
+2
-1
vllm/config.py
vllm/config.py
+0
-10
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+0
-6
vllm/engine/metrics.py
vllm/engine/metrics.py
+0
-43
No files found.
tests/async_engine/test_api_server.py
View file @
df5dafaa
...
@@ -25,27 +25,32 @@ def _query_server_long(prompt: str) -> dict:
...
@@ -25,27 +25,32 @@ def _query_server_long(prompt: str) -> dict:
@
pytest
.
fixture
@
pytest
.
fixture
def
api_server
(
tokenizer_pool_size
:
int
,
worker_use_ray
:
bool
):
def
api_server
(
tokenizer_pool_size
:
int
,
distributed_executor_backend
:
str
):
script_path
=
Path
(
__file__
).
parent
.
joinpath
(
script_path
=
Path
(
__file__
).
parent
.
joinpath
(
"api_server_async_engine.py"
).
absolute
()
"api_server_async_engine.py"
).
absolute
()
commands
=
[
commands
=
[
sys
.
executable
,
"-u"
,
sys
.
executable
,
str
(
script_path
),
"--model"
,
"facebook/opt-125m"
,
"--host"
,
"-u"
,
"127.0.0.1"
,
"--tokenizer-pool-size"
,
str
(
script_path
),
str
(
tokenizer_pool_size
)
"--model"
,
"facebook/opt-125m"
,
"--host"
,
"127.0.0.1"
,
"--tokenizer-pool-size"
,
str
(
tokenizer_pool_size
),
"--distributed-executor-backend"
,
distributed_executor_backend
,
]
]
if
worker_use_ray
:
commands
.
append
(
"--worker-use-ray"
)
uvicorn_process
=
subprocess
.
Popen
(
commands
)
uvicorn_process
=
subprocess
.
Popen
(
commands
)
yield
yield
uvicorn_process
.
terminate
()
uvicorn_process
.
terminate
()
@
pytest
.
mark
.
parametrize
(
"tokenizer_pool_size"
,
[
0
,
2
])
@
pytest
.
mark
.
parametrize
(
"tokenizer_pool_size"
,
[
0
,
2
])
@
pytest
.
mark
.
parametrize
(
"
worker_use_ray"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"
distributed_executor_backend"
,
[
"mp"
,
"ray"
])
def
test_api_server
(
api_server
,
tokenizer_pool_size
:
int
,
def
test_api_server
(
api_server
,
tokenizer_pool_size
:
int
,
worker_use_ray
:
bool
):
distributed_executor_backend
:
str
):
"""
"""
Run the API server and test it.
Run the API server and test it.
...
...
tests/basic_correctness/test_preemption.py
View file @
df5dafaa
...
@@ -29,10 +29,10 @@ def check_settings():
...
@@ -29,10 +29,10 @@ def check_settings():
@
pytest
.
fixture
@
pytest
.
fixture
def
worker_use_ray
()
->
bool
:
def
distributed_executor_backend
()
->
str
:
# When SPMD worker is used, use
ray_use_worker=True
# When SPMD worker is used, use
distributed_executor_backend="ray"
# to test delta input optimization works with preemption.
# to test delta input optimization works with preemption.
return
envs
.
VLLM_USE_RAY_SPMD_WORKER
return
"ray"
if
envs
.
VLLM_USE_RAY_SPMD_WORKER
else
"mp"
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
...
@@ -47,7 +47,7 @@ def test_chunked_prefill_recompute(
...
@@ -47,7 +47,7 @@ def test_chunked_prefill_recompute(
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
chunked_prefill_token_size
:
int
,
chunked_prefill_token_size
:
int
,
worker_use_ray
:
bool
,
distributed_executor_backend
:
str
,
)
->
None
:
)
->
None
:
"""Ensure that chunked prefill works with preemption."""
"""Ensure that chunked prefill works with preemption."""
max_num_seqs
=
min
(
chunked_prefill_token_size
,
256
)
max_num_seqs
=
min
(
chunked_prefill_token_size
,
256
)
...
@@ -66,7 +66,7 @@ def test_chunked_prefill_recompute(
...
@@ -66,7 +66,7 @@ def test_chunked_prefill_recompute(
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_batched_tokens
=
max_num_batched_tokens
,
enable_chunked_prefill
=
enable_chunked_prefill
,
enable_chunked_prefill
=
enable_chunked_prefill
,
max_num_seqs
=
max_num_seqs
,
max_num_seqs
=
max_num_seqs
,
worker_use_ray
=
worker_use_ray
,
distributed_executor_backend
=
distributed_executor_backend
,
disable_log_stats
=
False
,
disable_log_stats
=
False
,
)
as
vllm_model
:
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
...
@@ -93,7 +93,7 @@ def test_preemption(
...
@@ -93,7 +93,7 @@ def test_preemption(
model
:
str
,
model
:
str
,
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
worker_use_ray
:
bool
,
distributed_executor_backend
:
str
,
)
->
None
:
)
->
None
:
"""By default, recompute preemption is enabled"""
"""By default, recompute preemption is enabled"""
...
@@ -104,7 +104,7 @@ def test_preemption(
...
@@ -104,7 +104,7 @@ def test_preemption(
model
,
model
,
dtype
=
dtype
,
dtype
=
dtype
,
disable_log_stats
=
False
,
disable_log_stats
=
False
,
worker_use_ray
=
worker_use_ray
,
distributed_executor_backend
=
distributed_executor_backend
,
)
as
vllm_model
:
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
assert
(
vllm_model
.
model
.
llm_engine
.
scheduler
[
0
].
artificial_preempt_cnt
assert
(
vllm_model
.
model
.
llm_engine
.
scheduler
[
0
].
artificial_preempt_cnt
...
@@ -144,7 +144,7 @@ def test_preemption_infeasible(
...
@@ -144,7 +144,7 @@ def test_preemption_infeasible(
model
:
str
,
model
:
str
,
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
worker_use_ray
:
bool
,
distributed_executor_backend
:
str
,
)
->
None
:
)
->
None
:
"""Verify infeasible preemption request will be ignored."""
"""Verify infeasible preemption request will be ignored."""
BLOCK_SIZE
=
16
BLOCK_SIZE
=
16
...
@@ -159,7 +159,7 @@ def test_preemption_infeasible(
...
@@ -159,7 +159,7 @@ def test_preemption_infeasible(
# ignored instead of hanging forever.
# ignored instead of hanging forever.
num_gpu_blocks_override
=
prefill_blocks
+
decode_blocks
//
2
,
num_gpu_blocks_override
=
prefill_blocks
+
decode_blocks
//
2
,
max_model_len
=
((
prefill_blocks
+
decode_blocks
//
2
)
*
BLOCK_SIZE
),
max_model_len
=
((
prefill_blocks
+
decode_blocks
//
2
)
*
BLOCK_SIZE
),
worker_use_ray
=
worker_use_ray
,
distributed_executor_backend
=
distributed_executor_backend
,
)
as
vllm_model
:
)
as
vllm_model
:
sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
ignore_eos
=
True
)
ignore_eos
=
True
)
...
...
tests/multi_step/test_correctness_async_llm.py
View file @
df5dafaa
...
@@ -16,7 +16,8 @@ NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps
...
@@ -16,7 +16,8 @@ NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps
NUM_PROMPTS
=
[
10
]
NUM_PROMPTS
=
[
10
]
DEFAULT_SERVER_ARGS
:
List
[
str
]
=
[
DEFAULT_SERVER_ARGS
:
List
[
str
]
=
[
"--worker-use-ray"
,
"--distributed-executor-backend"
,
"ray"
,
"--gpu-memory-utilization"
,
"--gpu-memory-utilization"
,
"0.85"
,
"0.85"
,
"--swap-space"
,
"--swap-space"
,
...
...
vllm/config.py
View file @
df5dafaa
...
@@ -1227,9 +1227,6 @@ class ParallelConfig:
...
@@ -1227,9 +1227,6 @@ class ParallelConfig:
pipeline_parallel_size
:
int
=
1
# Number of pipeline parallel groups.
pipeline_parallel_size
:
int
=
1
# Number of pipeline parallel groups.
tensor_parallel_size
:
int
=
1
# Number of tensor parallel groups.
tensor_parallel_size
:
int
=
1
# Number of tensor parallel groups.
# Deprecated, use distributed_executor_backend instead.
worker_use_ray
:
Optional
[
bool
]
=
None
# Maximum number of multiple batches
# Maximum number of multiple batches
# when load model sequentially. To avoid RAM OOM when using tensor
# when load model sequentially. To avoid RAM OOM when using tensor
# parallel and large models.
# parallel and large models.
...
@@ -1283,13 +1280,6 @@ class ParallelConfig:
...
@@ -1283,13 +1280,6 @@ class ParallelConfig:
self
.
world_size
=
self
.
pipeline_parallel_size
*
\
self
.
world_size
=
self
.
pipeline_parallel_size
*
\
self
.
tensor_parallel_size
self
.
tensor_parallel_size
if
self
.
worker_use_ray
:
if
self
.
distributed_executor_backend
is
None
:
self
.
distributed_executor_backend
=
"ray"
elif
not
self
.
use_ray
:
raise
ValueError
(
f
"worker-use-ray can't be used with "
f
"distributed executor backend "
f
"'
{
self
.
distributed_executor_backend
}
'."
)
ray_only_devices
=
[
"tpu"
]
ray_only_devices
=
[
"tpu"
]
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
if
(
current_platform
.
device_type
in
ray_only_devices
if
(
current_platform
.
device_type
in
ray_only_devices
...
...
vllm/engine/arg_utils.py
View file @
df5dafaa
...
@@ -100,7 +100,6 @@ class EngineArgs:
...
@@ -100,7 +100,6 @@ class EngineArgs:
kv_cache_dtype
:
str
=
'auto'
kv_cache_dtype
:
str
=
'auto'
seed
:
int
=
0
seed
:
int
=
0
max_model_len
:
Optional
[
int
]
=
None
max_model_len
:
Optional
[
int
]
=
None
worker_use_ray
:
bool
=
False
# Note: Specifying a custom executor backend by passing a class
# Note: Specifying a custom executor backend by passing a class
# is intended for expert use only. The API may change without
# is intended for expert use only. The API may change without
# notice.
# notice.
...
@@ -389,10 +388,6 @@ class EngineArgs:
...
@@ -389,10 +388,6 @@ class EngineArgs:
'to "ray" if Ray is installed and fail otherwise. Note that tpu '
'to "ray" if Ray is installed and fail otherwise. Note that tpu '
'only supports Ray for distributed inference.'
)
'only supports Ray for distributed inference.'
)
parser
.
add_argument
(
'--worker-use-ray'
,
action
=
'store_true'
,
help
=
'Deprecated, use ``--distributed-executor-backend=ray``.'
)
parser
.
add_argument
(
'--pipeline-parallel-size'
,
parser
.
add_argument
(
'--pipeline-parallel-size'
,
'-pp'
,
'-pp'
,
type
=
int
,
type
=
int
,
...
@@ -1071,7 +1066,6 @@ class EngineArgs:
...
@@ -1071,7 +1066,6 @@ class EngineArgs:
parallel_config
=
ParallelConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
self
.
pipeline_parallel_size
,
pipeline_parallel_size
=
self
.
pipeline_parallel_size
,
tensor_parallel_size
=
self
.
tensor_parallel_size
,
tensor_parallel_size
=
self
.
tensor_parallel_size
,
worker_use_ray
=
self
.
worker_use_ray
,
max_parallel_loading_workers
=
self
.
max_parallel_loading_workers
,
max_parallel_loading_workers
=
self
.
max_parallel_loading_workers
,
disable_custom_all_reduce
=
self
.
disable_custom_all_reduce
,
disable_custom_all_reduce
=
self
.
disable_custom_all_reduce
,
tokenizer_pool_config
=
TokenizerPoolConfig
.
create_config
(
tokenizer_pool_config
=
TokenizerPoolConfig
.
create_config
(
...
...
vllm/engine/metrics.py
View file @
df5dafaa
...
@@ -259,21 +259,6 @@ class Metrics:
...
@@ -259,21 +259,6 @@ class Metrics:
documentation
=
"Number of emitted tokens."
,
documentation
=
"Number of emitted tokens."
,
labelnames
=
labelnames
))
labelnames
=
labelnames
))
# Deprecated in favor of vllm:prompt_tokens_total
self
.
gauge_avg_prompt_throughput
=
self
.
_gauge_cls
(
name
=
"vllm:avg_prompt_throughput_toks_per_s"
,
documentation
=
"Average prefill throughput in tokens/s."
,
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
,
)
# Deprecated in favor of vllm:generation_tokens_total
self
.
gauge_avg_generation_throughput
=
self
.
_gauge_cls
(
name
=
"vllm:avg_generation_throughput_toks_per_s"
,
documentation
=
"Average generation throughput in tokens/s."
,
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
,
)
# end-metrics-definitions
# end-metrics-definitions
...
@@ -635,20 +620,6 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -635,20 +620,6 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
_log_histogram
(
self
.
metrics
.
histogram_max_tokens_request
,
self
.
_log_histogram
(
self
.
metrics
.
histogram_max_tokens_request
,
stats
.
max_tokens_requests
)
stats
.
max_tokens_requests
)
def
_log_prometheus_interval
(
self
,
prompt_throughput
:
float
,
generation_throughput
:
float
)
->
None
:
# Logs metrics to prometheus that are computed every logging_interval.
# Support legacy gauge metrics that make throughput calculations on
# the vLLM side. Moving forward, we should use counters like
# counter_prompt_tokens, counter_generation_tokens
# Which log raw data and calculate summaries using rate() on the
# grafana/prometheus side. See
# https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
self
.
metrics
.
gauge_avg_prompt_throughput
.
labels
(
**
self
.
labels
).
set
(
prompt_throughput
)
self
.
metrics
.
gauge_avg_generation_throughput
.
labels
(
**
self
.
labels
).
set
(
generation_throughput
)
def
log
(
self
,
stats
:
Stats
):
def
log
(
self
,
stats
:
Stats
):
"""Logs to prometheus and tracked stats every iteration."""
"""Logs to prometheus and tracked stats every iteration."""
# Log to prometheus.
# Log to prometheus.
...
@@ -664,20 +635,6 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -664,20 +635,6 @@ class PrometheusStatLogger(StatLoggerBase):
# Log locally every local_interval seconds.
# Log locally every local_interval seconds.
if
local_interval_elapsed
(
stats
.
now
,
self
.
last_local_log
,
if
local_interval_elapsed
(
stats
.
now
,
self
.
last_local_log
,
self
.
local_interval
):
self
.
local_interval
):
# Compute summary metrics for tracked stats (and log them
# to promethus if applicable).
prompt_throughput
=
get_throughput
(
self
.
num_prompt_tokens
,
now
=
stats
.
now
,
last_log
=
self
.
last_local_log
)
generation_throughput
=
get_throughput
(
self
.
num_generation_tokens
,
now
=
stats
.
now
,
last_log
=
self
.
last_local_log
)
self
.
_log_prometheus_interval
(
prompt_throughput
=
prompt_throughput
,
generation_throughput
=
generation_throughput
)
if
self
.
spec_decode_metrics
is
not
None
:
if
self
.
spec_decode_metrics
is
not
None
:
self
.
_log_gauge
(
self
.
_log_gauge
(
self
.
metrics
.
gauge_spec_decode_draft_acceptance_rate
,
self
.
metrics
.
gauge_spec_decode_draft_acceptance_rate
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment