SIYIXNI / vllm · Commit 91fce82c
"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "2e754b87af8b2bdb3dfb4e6838426ad7654ba591"
Unverified commit 91fce82c, authored Oct 11, 2023 by yhlskt23, committed by GitHub on Oct 10, 2023.

change the timing of sorting logits (#1309)

Parent: ac5cf86a
Showing 1 changed file with 16 additions and 24 deletions.

vllm/model_executor/layers/sampler.py (+16, -24)
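In effect, the commit moves the sorting of sequences by sampling type out of _prune_hidden_states and into _sample: instead of physically reordering the pruned hidden states up front and slicing contiguous blocks per category later, the new code keeps everything in the original batch order and gathers each category's rows by index at sampling time. A minimal sketch of the two indexing strategies, with toy tensors and a made-up grouping (this is not vLLM code):

import torch

logits = torch.arange(12.0).reshape(4, 3)      # 4 sequences, vocab size 3
groups = {"greedy": [0, 2], "random": [1, 3]}  # hypothetical category -> rows

# Before this commit: physically reorder rows by category, then slice
# contiguous blocks out of the sorted tensor.
order = groups["greedy"] + groups["random"]
sorted_logits = logits[order]
greedy_old = sorted_logits[:len(groups["greedy"])]

# After this commit: gather each category directly by its row indices;
# the tensor keeps its original order throughout.
greedy_new = logits[groups["greedy"]]

assert torch.equal(greedy_old, greedy_new)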
@@ -102,30 +102,24 @@ def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
-    last_token_indices = {t: [] for t in SamplingType}
+    last_token_indices = []
     start_idx = 0
     for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        sampling_type = sampling_params.sampling_type
+        seq_ids, _ = seq_group
         if i < input_metadata.num_prompts:
             assert len(seq_ids) == 1, "Prompt input should have only one seq."
             prompt_len = input_metadata.prompt_lens[i]
-            last_token_indices[sampling_type].append(start_idx + prompt_len -
-                                                     1)
+            last_token_indices.append(start_idx + prompt_len - 1)
             start_idx += prompt_len
         else:
             num_seqs = len(seq_ids)
-            last_token_indices[sampling_type].extend(
-                range(start_idx, start_idx + num_seqs))
+            last_token_indices.extend(range(start_idx, start_idx + num_seqs))
             start_idx += num_seqs
 
-    all_last_token_indices = []
-    for sampling_type in SamplingType:
-        all_last_token_indices.extend(last_token_indices[sampling_type])
-    all_last_token_indices = torch.tensor(all_last_token_indices,
-                                          dtype=torch.long,
-                                          device=hidden_states.device)
-    return hidden_states.index_select(0, all_last_token_indices)
+    last_token_indices = torch.tensor(last_token_indices,
+                                      dtype=torch.long,
+                                      device=hidden_states.device)
+    return hidden_states.index_select(0, last_token_indices)
 
 
 def _get_penalties(
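After this hunk, _prune_hidden_states only picks out the one hidden state per sequence that feeds sampling, in the original batch order. A standalone sketch of that indexing with fake data (plain lists stand in for InputMetadata; the shapes and prompt lengths are invented):

import torch

hidden_states = torch.arange(14.0).reshape(7, 2)  # 7 token positions, hidden size 2
prompt_lens = [3, 4]  # two prompt groups, no generation groups, for simplicity

last_token_indices = []
start_idx = 0
for prompt_len in prompt_lens:
    # Only the final position of each prompt produces a next token.
    last_token_indices.append(start_idx + prompt_len - 1)
    start_idx += prompt_len

idx = torch.tensor(last_token_indices, dtype=torch.long)
pruned = hidden_states.index_select(0, idx)  # selects rows 2 and 6
assert pruned.tolist() == [[4.0, 5.0], [12.0, 13.0]]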
@@ -424,27 +418,26 @@ def _sample(
     input_metadata: InputMetadata,
 ) -> SamplerOutput:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
-    category_num_tokens = {t: 0 for t in SamplingType}
+    start_idx = 0
+    categorized_seq_ids = {t: [] for t in SamplingType}
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
         categorized_seq_group_ids[sampling_type].append(i)
         num_seqs = len(seq_ids)
-        category_num_tokens[sampling_type] += num_seqs
+        categorized_seq_ids[sampling_type].extend(
+            range(start_idx, start_idx + num_seqs))
+        start_idx += num_seqs
 
     seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
-    category_start_idx = 0
     for sampling_type in SamplingType:
         seq_group_ids = categorized_seq_group_ids[sampling_type]
         seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
         is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
-        num_tokens = category_num_tokens[sampling_type]
+        num_tokens = len(categorized_seq_ids[sampling_type])
         if num_tokens == 0:
             continue
-        category_logprobs = logprobs[category_start_idx:category_start_idx +
-                                     num_tokens]
-        category_probs = probs[category_start_idx:category_start_idx +
-                               num_tokens]
+        category_logprobs = logprobs[categorized_seq_ids[sampling_type]]
+        category_probs = probs[categorized_seq_ids[sampling_type]]
         if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, category_logprobs)
         elif sampling_type == SamplingType.RANDOM:
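This second hunk is where the sorting now effectively lives: categorized_seq_ids records, per sampling type, which rows of probs/logprobs belong to it, and fancy indexing pulls those rows out on demand, so the running category_start_idx offset is no longer needed. A sketch with a toy stand-in for vLLM's SamplingType and invented group sizes:

from enum import Enum

import torch

class SamplingType(Enum):
    GREEDY = 0
    RANDOM = 1

# (sampling_type, num_seqs) per sequence group, in original batch order.
seq_groups = [(SamplingType.RANDOM, 1), (SamplingType.GREEDY, 2),
              (SamplingType.RANDOM, 3)]

categorized_seq_ids = {t: [] for t in SamplingType}
start_idx = 0
for sampling_type, num_seqs in seq_groups:
    categorized_seq_ids[sampling_type].extend(
        range(start_idx, start_idx + num_seqs))
    start_idx += num_seqs

logprobs = torch.randn(start_idx, 5)  # 6 rows, vocab size 5
for sampling_type in SamplingType:
    ids = categorized_seq_ids[sampling_type]
    if not ids:
        continue
    # Row-index gather: logprobs itself is never reordered.
    category_logprobs = logprobs[ids]
    assert category_logprobs.shape[0] == len(ids)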
@@ -497,6 +490,5 @@ def _sample(
             sample_idx += num_parent_seqs
             result_idx += num_results
         assert sample_idx == num_tokens
-        category_start_idx += num_tokens
 
     return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
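The unchanged return line is what makes the regrouping safe in both versions: outputs are keyed by their original group index and read back in order, so the category-by-category processing order never leaks into the result. A toy illustration, with strings in place of SequenceOutputs lists and an invented grouping:

seq_outputs_dict = {}
group_ids_by_type = {"greedy": [1], "random": [0, 2]}  # hypothetical grouping
for sampling_type in ("greedy", "random"):  # category processing order
    for group_idx in group_ids_by_type[sampling_type]:
        seq_outputs_dict[group_idx] = f"outputs-for-group-{group_idx}"

result = [seq_outputs_dict[i] for i in range(3)]
assert result == ["outputs-for-group-0", "outputs-for-group-1",
                  "outputs-for-group-2"]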