Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
91fce82c
Unverified
Commit
91fce82c
authored
Oct 11, 2023
by
yhlskt23
Committed by
GitHub
Oct 10, 2023
Browse files
change the timing of sorting logits (#1309)
parent
ac5cf86a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
24 deletions
+16
-24
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+16
-24
No files found.
vllm/model_executor/layers/sampler.py
View file @
91fce82c
...
...
@@ -102,30 +102,24 @@ def _prune_hidden_states(
hidden_states
:
torch
.
Tensor
,
input_metadata
:
InputMetadata
,
)
->
torch
.
Tensor
:
last_token_indices
=
{
t
:
[]
for
t
in
SamplingType
}
last_token_indices
=
[]
start_idx
=
0
for
i
,
seq_group
in
enumerate
(
input_metadata
.
seq_groups
):
seq_ids
,
sampling_params
=
seq_group
sampling_type
=
sampling_params
.
sampling_type
seq_ids
,
_
=
seq_group
if
i
<
input_metadata
.
num_prompts
:
assert
len
(
seq_ids
)
==
1
,
"Prompt input should have only one seq."
prompt_len
=
input_metadata
.
prompt_lens
[
i
]
last_token_indices
[
sampling_type
].
append
(
start_idx
+
prompt_len
-
1
)
last_token_indices
.
append
(
start_idx
+
prompt_len
-
1
)
start_idx
+=
prompt_len
else
:
num_seqs
=
len
(
seq_ids
)
last_token_indices
[
sampling_type
].
extend
(
range
(
start_idx
,
start_idx
+
num_seqs
))
last_token_indices
.
extend
(
range
(
start_idx
,
start_idx
+
num_seqs
))
start_idx
+=
num_seqs
all_last_token_indices
=
[]
for
sampling_type
in
SamplingType
:
all_last_token_indices
.
extend
(
last_token_indices
[
sampling_type
])
all_last_token_indices
=
torch
.
tensor
(
all_last_token_indices
,
dtype
=
torch
.
long
,
device
=
hidden_states
.
device
)
return
hidden_states
.
index_select
(
0
,
all_last_token_indices
)
last_token_indices
=
torch
.
tensor
(
last_token_indices
,
dtype
=
torch
.
long
,
device
=
hidden_states
.
device
)
return
hidden_states
.
index_select
(
0
,
last_token_indices
)
def
_get_penalties
(
...
...
@@ -424,27 +418,26 @@ def _sample(
input_metadata
:
InputMetadata
,
)
->
SamplerOutput
:
categorized_seq_group_ids
=
{
t
:
[]
for
t
in
SamplingType
}
category_num_tokens
=
{
t
:
0
for
t
in
SamplingType
}
start_idx
=
0
categorized_seq_ids
=
{
t
:
[]
for
t
in
SamplingType
}
for
i
,
seq_group
in
enumerate
(
input_metadata
.
seq_groups
):
seq_ids
,
sampling_params
=
seq_group
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
num_seqs
=
len
(
seq_ids
)
category_num_tokens
[
sampling_type
]
+=
num_seqs
categorized_seq_ids
[
sampling_type
].
extend
(
range
(
start_idx
,
start_idx
+
num_seqs
))
start_idx
+=
num_seqs
seq_outputs_dict
:
Dict
[
int
,
List
[
SequenceOutputs
]]
=
{}
category_start_idx
=
0
for
sampling_type
in
SamplingType
:
seq_group_ids
=
categorized_seq_group_ids
[
sampling_type
]
seq_groups
=
[
input_metadata
.
seq_groups
[
i
]
for
i
in
seq_group_ids
]
is_prompts
=
[
i
<
input_metadata
.
num_prompts
for
i
in
seq_group_ids
]
num_tokens
=
categor
y_num_token
s
[
sampling_type
]
num_tokens
=
len
(
categor
ized_seq_id
s
[
sampling_type
]
)
if
num_tokens
==
0
:
continue
category_logprobs
=
logprobs
[
category_start_idx
:
category_start_idx
+
num_tokens
]
category_probs
=
probs
[
category_start_idx
:
category_start_idx
+
num_tokens
]
category_logprobs
=
logprobs
[
categorized_seq_ids
[
sampling_type
]]
category_probs
=
probs
[
categorized_seq_ids
[
sampling_type
]]
if
sampling_type
==
SamplingType
.
GREEDY
:
sample_results
=
_greedy_sample
(
seq_groups
,
category_logprobs
)
elif
sampling_type
==
SamplingType
.
RANDOM
:
...
...
@@ -497,6 +490,5 @@ def _sample(
sample_idx
+=
num_parent_seqs
result_idx
+=
num_results
assert
sample_idx
==
num_tokens
category_start_idx
+=
num_tokens
return
[
seq_outputs_dict
[
i
]
for
i
in
range
(
len
(
input_metadata
.
seq_groups
))]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment