Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2bb0489c
Unverified
Commit
2bb0489c
authored
Jul 16, 2024
by
Peng Guanwen
Committed by
GitHub
Jul 16, 2024
Browse files
[Core] Use numpy to speed up padded token processing (#6442)
parent
7508a3dc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
20 deletions
+18
-20
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+18
-20
No files found.
vllm/model_executor/sampling_metadata.py
View file @
2bb0489c
...
...
@@ -2,6 +2,7 @@ import random
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
numpy
as
np
import
torch
from
vllm.model_executor.layers.ops.sample
import
get_num_triton_sampler_splits
...
...
@@ -457,16 +458,20 @@ class SamplingTensors:
if
do_penalties
:
prompt_max_len
=
max
([
len
(
tokens
)
for
tokens
in
prompt_tokens
],
default
=
0
)
prompt_padded_tokens
=
[
tokens
+
[
vocab_size
]
*
(
prompt_max_len
-
len
(
tokens
))
for
tokens
in
prompt_tokens
]
prompt_padded_tokens
=
np
.
full
(
(
len
(
prompt_tokens
),
prompt_max_len
),
vocab_size
,
dtype
=
np
.
int64
)
for
i
,
tokens
in
enumerate
(
prompt_tokens
):
prompt_padded_tokens
[
i
,
:
len
(
tokens
)]
=
tokens
output_max_len
=
max
([
len
(
tokens
)
for
tokens
in
output_tokens
],
default
=
0
)
output_padded_tokens
=
[
tokens
+
[
vocab_size
]
*
(
output_max_len
-
len
(
tokens
))
for
tokens
in
output_tokens
]
output_padded_tokens
=
np
.
full
(
(
len
(
output_tokens
),
output_max_len
),
vocab_size
,
dtype
=
np
.
int64
)
for
i
,
tokens
in
enumerate
(
output_tokens
):
output_padded_tokens
[
i
,
:
len
(
tokens
)]
=
tokens
temperatures_t
=
torch
.
tensor
(
temperatures
,
...
...
@@ -517,18 +522,11 @@ class SamplingTensors:
pin_memory
=
pin_memory
,
)
if
do_penalties
:
prompt_tensor
=
torch
.
tensor
(
prompt_padded_tokens
,
device
=
"cpu"
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
)
output_tensor
=
torch
.
tensor
(
output_padded_tokens
,
device
=
"cpu"
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
)
prompt_tensor
=
torch
.
from_numpy
(
prompt_padded_tokens
)
output_tensor
=
torch
.
from_numpy
(
output_padded_tokens
)
if
pin_memory
:
prompt_tensor
=
prompt_tensor
.
pin_memory
()
output_tensor
=
output_tensor
.
pin_memory
()
else
:
prompt_tensor
=
None
output_tensor
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment