Commit 6f6ea0a8 authored by 王敏
Browse files

[perf]优化异步调度+并行解码 step之间的空泡,实现kernel提前下发

parent 319506a5
......@@ -2188,28 +2188,55 @@ class GPUModelRunner(
)
# [0, 1, 2, 5, 6, 9]
target_logits_indices += arange
draft_token_indices = target_logits_indices + 1
# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
self.device, non_blocking=True
)
cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to(
self.device, non_blocking=True
)
logits_indices = torch.from_numpy(logits_indices).to(
self.device, non_blocking=True
)
target_logits_indices = torch.from_numpy(target_logits_indices).to(
self.device, non_blocking=True
)
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
self.device, non_blocking=True
)
# cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
# self.device, non_blocking=True
# )
# cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to(
# self.device, non_blocking=True
# )
# logits_indices = torch.from_numpy(logits_indices).to(
# self.device, non_blocking=True
# )
# target_logits_indices = torch.from_numpy(target_logits_indices).to(
# self.device, non_blocking=True
# )
# bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
# self.device, non_blocking=True
# )
# # Compute the draft token ids.
# # draft_token_indices: [ 1, 2, 3, 105, 106, 208]
# draft_token_ids = self.input_ids.gpu[logits_indices]
# draft_token_ids = draft_token_ids[target_logits_indices + 1]
# Optimize the H2D in the process of creating spec decode metadata
fused_meta_data = cu_num_draft_tokens.tolist() + cu_num_sampled_tokens.tolist()\
+ logits_indices.tolist() + target_logits_indices.tolist() + bonus_logits_indices.tolist()\
+ draft_token_indices.tolist()
fused_meta_data_len = np.array([len(cu_num_draft_tokens), len(cu_num_sampled_tokens),\
len(logits_indices), len(target_logits_indices),\
len(bonus_logits_indices), len(draft_token_indices)], dtype=np.int32)
cu_fused_meta_data_len = np.cumsum(fused_meta_data_len, dtype=np.int32)
fused_meta_data = torch.tensor(
fused_meta_data, dtype=torch.int32, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
cu_num_draft_tokens = fused_meta_data[:cu_fused_meta_data_len[0]]
cu_num_sampled_tokens = fused_meta_data[cu_fused_meta_data_len[0]:cu_fused_meta_data_len[1]]
logits_indices = fused_meta_data[cu_fused_meta_data_len[1]:cu_fused_meta_data_len[2]]
target_logits_indices = fused_meta_data[cu_fused_meta_data_len[2]:cu_fused_meta_data_len[3]]
bonus_logits_indices = fused_meta_data[cu_fused_meta_data_len[3]:cu_fused_meta_data_len[4]]
draft_token_indices = fused_meta_data[cu_fused_meta_data_len[4]:cu_fused_meta_data_len[5]]
# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
draft_token_ids = self.input_ids.gpu[logits_indices]
draft_token_ids = draft_token_ids[target_logits_indices + 1]
draft_token_ids = draft_token_ids[draft_token_indices]
return SpecDecodeMetadata(
draft_token_ids=draft_token_ids,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment