Commit 08c2298a authored by guanyu1's avatar guanyu1
Browse files

sampler修改

parent 9bd32639
......@@ -69,7 +69,14 @@ class SampleResultArgsType:
sampling_metadata: SamplingMetadata
greedy_samples: Optional[torch.Tensor]
beam_search_logprobs: Optional[torch.Tensor]
# Implemented by guanyu
@dataclass
class SampleDeviceToDevices:
num_parent_seq: torch.Tensor=None
seq_id:torch.Tensor=None
random_samples:torch.Tensor=None
sample_idx:int=None
d2d_data=SampleDeviceToDevices()
# Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling)
......@@ -496,6 +503,7 @@ def _random_sample(
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum n value of the prompt phase requests.
#random_samples = random_samples.cpu()删除,取消gpu->cpu之间的同步
random_samples = random_samples.cpu()
sample_idx = 0
results: SampleResultType = []
......@@ -508,6 +516,7 @@ def _random_sample(
sampling_params = seq_group.sampling_params
is_prompt = seq_group.is_prompt
num_parent_seqs = len(seq_ids)
d2d_data.num_parent_seq = num_parent_seqs
if is_prompt:
# Prompt phase.
parent_ids = [0] * sampling_params.n
......@@ -520,6 +529,7 @@ def _random_sample(
num_parent_seqs, 0].tolist()
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
d2d_data.sample_idx=sample_idx
return results
......@@ -697,6 +707,7 @@ def get_pythonized_sample_results(
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
d2d_data.random_samples=multinomial_samples[sampling_type]#记录random_samples的数据
sample_results = _random_sample(seq_groups,
multinomial_samples[sampling_type])
elif sampling_type == SamplingType.BEAM:
......@@ -733,9 +744,13 @@ def _sample_with_torch(
categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
t: []
for t in SamplingType
}
}#初始化各种结果存储容器然后按照类型分类
print(f'sampling_metadata.seq_groups的长度:{len(sampling_metadata.seq_groups)}')
# 初始化一个tensor张量用于保存seq_id,初始值为-1
d2d_data.seq_id=torch.zeros(len(sampling_metadata.seq_groups),1)-1
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
d2d_data.seq_id[i]=seq_group.seq_ids[0]#将 i对应的seq_id存储到d2d_data.seq_id中
sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment