Commit 08c2298a authored by guanyu1's avatar guanyu1
Browse files

sampler修改

parent 9bd32639
...@@ -69,7 +69,14 @@ class SampleResultArgsType: ...@@ -69,7 +69,14 @@ class SampleResultArgsType:
sampling_metadata: SamplingMetadata sampling_metadata: SamplingMetadata
greedy_samples: Optional[torch.Tensor] greedy_samples: Optional[torch.Tensor]
beam_search_logprobs: Optional[torch.Tensor] beam_search_logprobs: Optional[torch.Tensor]
# Implemented by guanyu
@dataclass
class SampleDeviceToDevices:
num_parent_seq: torch.Tensor=None
seq_id:torch.Tensor=None
random_samples:torch.Tensor=None
sample_idx:int=None
d2d_data=SampleDeviceToDevices()
# Union of non-deferred (single-step scheduling) # Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling) # vs deferred (multi-step scheduling)
...@@ -496,6 +503,7 @@ def _random_sample( ...@@ -496,6 +503,7 @@ def _random_sample(
seq_group has do_sample=False, tuple contains ([], []) seq_group has do_sample=False, tuple contains ([], [])
""" """
# Find the maximum n value of the prompt phase requests. # Find the maximum n value of the prompt phase requests.
#random_samples = random_samples.cpu()删除,取消gpu->cpu之间的同步
random_samples = random_samples.cpu() random_samples = random_samples.cpu()
sample_idx = 0 sample_idx = 0
results: SampleResultType = [] results: SampleResultType = []
...@@ -508,6 +516,7 @@ def _random_sample( ...@@ -508,6 +516,7 @@ def _random_sample(
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
is_prompt = seq_group.is_prompt is_prompt = seq_group.is_prompt
num_parent_seqs = len(seq_ids) num_parent_seqs = len(seq_ids)
d2d_data.num_parent_seq = num_parent_seqs
if is_prompt: if is_prompt:
# Prompt phase. # Prompt phase.
parent_ids = [0] * sampling_params.n parent_ids = [0] * sampling_params.n
...@@ -520,6 +529,7 @@ def _random_sample( ...@@ -520,6 +529,7 @@ def _random_sample(
num_parent_seqs, 0].tolist() num_parent_seqs, 0].tolist()
results.append((next_token_ids, parent_ids)) results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs sample_idx += num_parent_seqs
d2d_data.sample_idx=sample_idx
return results return results
...@@ -697,6 +707,7 @@ def get_pythonized_sample_results( ...@@ -697,6 +707,7 @@ def get_pythonized_sample_results(
if sampling_type == SamplingType.GREEDY: if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples) sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
d2d_data.random_samples=multinomial_samples[sampling_type]#记录random_samples的数据
sample_results = _random_sample(seq_groups, sample_results = _random_sample(seq_groups,
multinomial_samples[sampling_type]) multinomial_samples[sampling_type])
elif sampling_type == SamplingType.BEAM: elif sampling_type == SamplingType.BEAM:
...@@ -733,9 +744,13 @@ def _sample_with_torch( ...@@ -733,9 +744,13 @@ def _sample_with_torch(
categorized_seq_group_ids: Dict[SamplingType, List[int]] = { categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
t: [] t: []
for t in SamplingType for t in SamplingType
} }#初始化各种结果存储容器然后按照类型分类
print(f'sampling_metadata.seq_groups的长度:{len(sampling_metadata.seq_groups)}')
# 初始化一个tensor张量用于保存seq_id,初始值为-1
d2d_data.seq_id=torch.zeros(len(sampling_metadata.seq_groups),1)-1
categorized_sample_indices = sampling_metadata.categorized_sample_indices categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups): for i, seq_group in enumerate(sampling_metadata.seq_groups):
d2d_data.seq_id[i]=seq_group.seq_ids[0]#将 i对应的seq_id存储到d2d_data.seq_id中
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i) categorized_seq_group_ids[sampling_type].append(i)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment