Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
437b76ff
Unverified
Commit
437b76ff
authored
Feb 24, 2025
by
Roger Wang
Committed by
GitHub
Feb 24, 2025
Browse files
[V1][Core] Fix memory issue with logits & sampling (#13721)
parent
f90a3755
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
49 additions
and
29 deletions
+49
-29
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+39
-29
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+10
-0
No files found.
vllm/v1/worker/gpu_model_runner.py
View file @
437b76ff
...
...
@@ -1179,6 +1179,43 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return
hidden_states
@torch.inference_mode()
def _dummy_sampler_run(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Run one logits + sampling pass using placeholder sampling metadata.

    Logits are computed from ``hidden_states`` via
    ``self.model.compute_logits``; a ``SamplingMetadata`` filled with
    constant placeholder values (one entry per row of the logits) is then
    handed to ``self.model.sample``.

    Args:
        hidden_states: Tensor forwarded unchanged to
            ``self.model.compute_logits``.

    Returns:
        Whatever ``self.model.sample`` returns for the placeholder
        metadata.
        NOTE(review): the annotation says ``torch.Tensor`` but the value
        is the sampler's output object -- confirm the intended type.
    """
    logits = self.model.compute_logits(hidden_states, None)
    # The leading dimension of the logits is used as the request count.
    batch_size = logits.shape[0]

    def _filled(value):
        # 1-D tensor with one `value` per request, on this runner's device.
        return torch.full((batch_size, ), value, device=self.device)

    placeholder_metadata = SamplingMetadata(
        temperature=_filled(0.5),
        all_greedy=False,
        all_random=False,
        spec_token_ids=None,
        top_p=_filled(0.9),
        # top-k is set to one less than the size of the logits' second
        # dimension.
        top_k=_filled(logits.shape[1] - 1),
        min_p=None,
        generators={},
        max_num_logprobs=None,
        no_penalties=True,
        prompt_token_ids=None,
        frequency_penalties=_filled(0.1),
        presence_penalties=_filled(0.1),
        repetition_penalties=_filled(0.1),
        # A distinct empty list per request (not a shared one).
        output_token_ids=[[] for _ in range(batch_size)],
        min_tokens={},
        logit_bias=[None] * batch_size,
        allowed_token_ids_mask=None,
    )
    return self.model.sample(logits=logits,
                             sampling_metadata=placeholder_metadata)
def
profile_run
(
self
)
->
None
:
# use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than by specializing on the value `None`.
...
...
@@ -1306,38 +1343,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dummy_kv_caches
)
if
get_pp_group
().
is_last_rank
:
hidden_states
=
hidden_states
[
logit_indices
]
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
dummy_tensors
=
lambda
v
:
torch
.
full
(
(
num_reqs
,
),
v
,
device
=
self
.
device
)
dummy_metadata
=
SamplingMetadata
(
temperature
=
dummy_tensors
(
0.5
),
all_greedy
=
False
,
all_random
=
False
,
spec_token_ids
=
None
,
top_p
=
dummy_tensors
(
0.9
),
top_k
=
dummy_tensors
(
logits
.
size
(
1
)
-
1
),
min_p
=
None
,
generators
=
{},
max_num_logprobs
=
None
,
no_penalties
=
True
,
prompt_token_ids
=
torch
.
ones_like
(
logits
,
dtype
=
torch
.
int64
),
frequency_penalties
=
dummy_tensors
(
0.1
),
presence_penalties
=
dummy_tensors
(
0.1
),
repetition_penalties
=
dummy_tensors
(
0.1
),
output_token_ids
=
[[]
for
_
in
range
(
num_reqs
)],
min_tokens
=
{},
logit_bias
=
[
None
for
_
in
range
(
num_reqs
)],
allowed_token_ids_mask
=
None
,
)
sampler_output
=
self
.
model
.
sample
(
logits
=
logits
,
sampling_metadata
=
dummy_metadata
)
sampler_output
=
self
.
_dummy_sampler_run
(
hidden_states
)
else
:
logits
=
None
sampler_output
=
None
dummy_metadata
=
None
torch
.
cuda
.
synchronize
()
del
hidden_states
,
logits
,
sampler_output
,
dummy_metadata
del
hidden_states
,
sampler_output
self
.
encoder_cache
.
clear
()
gc
.
collect
()
...
...
vllm/v1/worker/gpu_worker.py
View file @
437b76ff
...
...
@@ -211,6 +211,16 @@ class Worker(WorkerBase):
self
.
model_runner
.
_dummy_run
(
size
)
if
not
self
.
model_config
.
enforce_eager
:
self
.
model_runner
.
capture_model
()
# Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
self
.
model_runner
.
_dummy_sampler_run
(
hidden_states
=
self
.
model_runner
.
_dummy_run
(
num_tokens
=
self
.
scheduler_config
.
max_num_seqs
))
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed
(
self
.
model_config
.
seed
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment