Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
08c2298a
Commit
08c2298a
authored
Mar 17, 2025
by
guanyu1
Browse files
sampler修改
parent
9bd32639
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
2 deletions
+17
-2
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+17
-2
No files found.
vllm/model_executor/layers/sampler.py
View file @
08c2298a
...
...
@@ -69,7 +69,14 @@ class SampleResultArgsType:
sampling_metadata
:
SamplingMetadata
greedy_samples
:
Optional
[
torch
.
Tensor
]
beam_search_logprobs
:
Optional
[
torch
.
Tensor
]
# Implemented by guanyu
@dataclass
class SampleDeviceToDevices:
    """Mutable holder for intermediate values recorded during sampling.

    Every field defaults to ``None`` and is expected to be filled in by the
    sampling code after construction, so all annotations are ``Optional``.
    """

    # Number of parent sequences of the sequence group handled most recently;
    # the visible writer stores ``len(seq_ids)``, i.e. a plain int.
    num_parent_seq: Optional[int] = None
    # Tensor holding one sequence id per sequence group.
    # NOTE(review): the visible writer builds it with shape
    # (num_seq_groups, 1), pre-filled with -1 -- confirm with consumers.
    seq_id: Optional[torch.Tensor] = None
    # Raw multinomial samples handed to the random-sampling path.
    random_samples: Optional[torch.Tensor] = None
    # Running sample offset reached once random sampling has finished.
    sample_idx: Optional[int] = None
# Module-level instance of SampleDeviceToDevices, created with every field
# left at its default of None; code in this module assigns to its attributes.
# NOTE(review): this is shared mutable state -- confirm that it is never
# accessed concurrently.
d2d_data = SampleDeviceToDevices()
# Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling)
...
...
@@ -496,6 +503,7 @@ def _random_sample(
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum n value of the prompt phase requests.
# TODO: delete `random_samples = random_samples.cpu()` to eliminate the GPU->CPU synchronization.
random_samples
=
random_samples
.
cpu
()
sample_idx
=
0
results
:
SampleResultType
=
[]
...
...
@@ -508,6 +516,7 @@ def _random_sample(
sampling_params
=
seq_group
.
sampling_params
is_prompt
=
seq_group
.
is_prompt
num_parent_seqs
=
len
(
seq_ids
)
d2d_data
.
num_parent_seq
=
num_parent_seqs
if
is_prompt
:
# Prompt phase.
parent_ids
=
[
0
]
*
sampling_params
.
n
...
...
@@ -520,6 +529,7 @@ def _random_sample(
num_parent_seqs
,
0
].
tolist
()
results
.
append
((
next_token_ids
,
parent_ids
))
sample_idx
+=
num_parent_seqs
d2d_data
.
sample_idx
=
sample_idx
return
results
...
...
@@ -697,6 +707,7 @@ def get_pythonized_sample_results(
if
sampling_type
==
SamplingType
.
GREEDY
:
sample_results
=
_greedy_sample
(
seq_groups
,
greedy_samples
)
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
d2d_data
.
random_samples
=
multinomial_samples
[
sampling_type
]
# Record the random_samples data.
sample_results
=
_random_sample
(
seq_groups
,
multinomial_samples
[
sampling_type
])
elif
sampling_type
==
SamplingType
.
BEAM
:
...
...
@@ -733,9 +744,13 @@ def _sample_with_torch(
categorized_seq_group_ids
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
t
:
[]
for
t
in
SamplingType
}
}
# Initialize the result containers, then categorize the sequence groups by sampling type.
print
(
f
'sampling_metadata.seq_groups的长度:
{
len
(
sampling_metadata
.
seq_groups
)
}
'
)
# Initialize a tensor that holds the seq_id values, with every entry initially set to -1.
d2d_data
.
seq_id
=
torch
.
zeros
(
len
(
sampling_metadata
.
seq_groups
),
1
)
-
1
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
d2d_data
.
seq_id
[
i
]
=
seq_group
.
seq_ids
[
0
]
# Store the seq_id that corresponds to index i into d2d_data.seq_id.
sampling_params
=
seq_group
.
sampling_params
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment