Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a31cab75
Unverified
Commit
a31cab75
authored
Jun 06, 2024
by
Antoni Baum
Committed by
GitHub
Jun 06, 2024
Browse files
[Core] Avoid copying prompt/output tokens if no penalties are used (#5289)
parent
828da0d4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
50 additions
and
30 deletions
+50
-30
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+50
-30
No files found.
vllm/model_executor/sampling_metadata.py
View file @
a31cab75
...
@@ -386,16 +386,18 @@ class SamplingTensors:
...
@@ -386,16 +386,18 @@ class SamplingTensors:
presence_penalties
+=
[
0
]
*
prefill_len
presence_penalties
+=
[
0
]
*
prefill_len
frequency_penalties
+=
[
0
]
*
prefill_len
frequency_penalties
+=
[
0
]
*
prefill_len
repetition_penalties
+=
[
1
]
*
prefill_len
repetition_penalties
+=
[
1
]
*
prefill_len
prompt_tokens
.
extend
([]
for
_
in
range
(
prefill_len
))
if
do_penalties
:
output_tokens
.
extend
([]
for
_
in
range
(
prefill_len
))
prompt_tokens
.
extend
([]
for
_
in
range
(
prefill_len
))
output_tokens
.
extend
([]
for
_
in
range
(
prefill_len
))
if
seq_group
.
do_sample
:
if
seq_group
.
do_sample
:
sample_lens
=
len
(
seq_group
.
sample_indices
)
sample_lens
=
len
(
seq_group
.
sample_indices
)
assert
sample_lens
==
len
(
seq_ids
)
assert
sample_lens
==
len
(
seq_ids
)
for
seq_id
in
seq_ids
:
for
seq_id
in
seq_ids
:
seq_data
=
seq_group
.
seq_data
[
seq_id
]
seq_data
=
seq_group
.
seq_data
[
seq_id
]
prompt_tokens
.
append
(
seq_data
.
prompt_token_ids
)
if
do_penalties
:
output_tokens
.
append
(
seq_data
.
output_token_ids
)
prompt_tokens
.
append
(
seq_data
.
prompt_token_ids
)
output_tokens
.
append
(
seq_data
.
output_token_ids
)
temperatures
+=
[
temperature
]
*
len
(
seq_ids
)
temperatures
+=
[
temperature
]
*
len
(
seq_ids
)
top_ps
+=
[
top_p
]
*
len
(
seq_ids
)
top_ps
+=
[
top_p
]
*
len
(
seq_ids
)
top_ks
+=
[
top_k
]
*
len
(
seq_ids
)
top_ks
+=
[
top_k
]
*
len
(
seq_ids
)
...
@@ -443,18 +445,22 @@ class SamplingTensors:
...
@@ -443,18 +445,22 @@ class SamplingTensors:
# Note that the performance will be very bad without
# Note that the performance will be very bad without
# pinned memory.
# pinned memory.
pin_memory
=
is_pin_memory_available
()
pin_memory
=
is_pin_memory_available
()
prompt_max_len
=
max
([
len
(
tokens
)
for
tokens
in
prompt_tokens
],
default
=
0
)
do_penalties
=
prompt_tokens
or
output_tokens
prompt_padded_tokens
=
[
tokens
+
[
vocab_size
]
*
(
prompt_max_len
-
len
(
tokens
))
if
do_penalties
:
for
tokens
in
prompt_tokens
prompt_max_len
=
max
([
len
(
tokens
)
for
tokens
in
prompt_tokens
],
]
default
=
0
)
output_max_len
=
max
([
len
(
tokens
)
for
tokens
in
output_tokens
],
prompt_padded_tokens
=
[
default
=
0
)
tokens
+
[
vocab_size
]
*
(
prompt_max_len
-
len
(
tokens
))
output_padded_tokens
=
[
for
tokens
in
prompt_tokens
tokens
+
[
vocab_size
]
*
(
output_max_len
-
len
(
tokens
))
]
for
tokens
in
output_tokens
output_max_len
=
max
([
len
(
tokens
)
for
tokens
in
output_tokens
],
]
default
=
0
)
output_padded_tokens
=
[
tokens
+
[
vocab_size
]
*
(
output_max_len
-
len
(
tokens
))
for
tokens
in
output_tokens
]
temperatures_t
=
torch
.
tensor
(
temperatures_t
=
torch
.
tensor
(
temperatures
,
temperatures
,
...
@@ -504,18 +510,22 @@ class SamplingTensors:
...
@@ -504,18 +510,22 @@ class SamplingTensors:
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
)
)
prompt_tensor
=
torch
.
tensor
(
if
do_penalties
:
prompt_padded_tokens
,
prompt_tensor
=
torch
.
tensor
(
device
=
"cpu"
,
prompt_padded_tokens
,
dtype
=
torch
.
long
,
device
=
"cpu"
,
pin_memory
=
pin_memory
,
dtype
=
torch
.
long
,
)
pin_memory
=
pin_memory
,
output_tensor
=
torch
.
tensor
(
)
output_padded_tokens
,
output_tensor
=
torch
.
tensor
(
device
=
"cpu"
,
output_padded_tokens
,
dtype
=
torch
.
long
,
device
=
"cpu"
,
pin_memory
=
pin_memory
,
dtype
=
torch
.
long
,
)
pin_memory
=
pin_memory
,
)
else
:
prompt_tensor
=
None
output_tensor
=
None
# need to transpose and make contiguous to
# need to transpose and make contiguous to
# copy the tensor correctly.
# copy the tensor correctly.
# [batch_size, n_seeds] -> [n_seeds, batch_size]
# [batch_size, n_seeds] -> [n_seeds, batch_size]
...
@@ -538,6 +548,16 @@ class SamplingTensors:
...
@@ -538,6 +548,16 @@ class SamplingTensors:
extra_seeds_gpu
=
None
extra_seeds_gpu
=
None
sampling_seeds_gpu
=
sampling_seeds_gpu
[:
num_base_seeds
]
sampling_seeds_gpu
=
sampling_seeds_gpu
[:
num_base_seeds
]
if
do_penalties
:
prompt_tokens_gpu
=
prompt_tensor
.
to
(
device
=
device
,
non_blocking
=
True
)
output_tokens_gpu
=
output_tensor
.
to
(
device
=
device
,
non_blocking
=
True
)
else
:
empty_tensor
=
torch
.
empty
(
0
,
device
=
device
,
dtype
=
torch
.
long
)
prompt_tokens_gpu
=
empty_tensor
output_tokens_gpu
=
empty_tensor
return
cls
(
return
cls
(
temperatures
=
temperatures_t
.
to
(
device
=
device
,
non_blocking
=
True
),
temperatures
=
temperatures_t
.
to
(
device
=
device
,
non_blocking
=
True
),
top_ps
=
top_ps_t
.
to
(
device
=
device
,
non_blocking
=
True
),
top_ps
=
top_ps_t
.
to
(
device
=
device
,
non_blocking
=
True
),
...
@@ -549,8 +569,8 @@ class SamplingTensors:
...
@@ -549,8 +569,8 @@ class SamplingTensors:
non_blocking
=
True
),
non_blocking
=
True
),
repetition_penalties
=
repetition_penalties_t
.
to
(
device
=
device
,
repetition_penalties
=
repetition_penalties_t
.
to
(
device
=
device
,
non_blocking
=
True
),
non_blocking
=
True
),
prompt_tokens
=
prompt_tens
or
.
to
(
device
=
device
,
non_blocking
=
True
)
,
prompt_tokens
=
prompt_t
ok
ens
_gpu
,
output_tokens
=
output_tens
or
.
to
(
device
=
device
,
non_blocking
=
True
)
,
output_tokens
=
output_t
ok
ens
_gpu
,
sampling_seeds
=
sampling_seeds_gpu
,
sampling_seeds
=
sampling_seeds_gpu
,
sample_indices
=
sample_indices_t
.
to
(
device
=
device
,
sample_indices
=
sample_indices_t
.
to
(
device
=
device
,
non_blocking
=
True
),
non_blocking
=
True
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment