Commit 6f6ea0a8 authored by 王敏's avatar 王敏
Browse files

[perf]优化异步调度+并行解码 step之间的空泡,实现kernel提前下发

parent 319506a5
...@@ -2188,28 +2188,55 @@ class GPUModelRunner( ...@@ -2188,28 +2188,55 @@ class GPUModelRunner(
) )
# [0, 1, 2, 5, 6, 9] # [0, 1, 2, 5, 6, 9]
target_logits_indices += arange target_logits_indices += arange
draft_token_indices = target_logits_indices + 1
# TODO: Optimize the CPU -> GPU copy. # TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( # cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
self.device, non_blocking=True # self.device, non_blocking=True
) # )
cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to( # cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to(
self.device, non_blocking=True # self.device, non_blocking=True
) # )
logits_indices = torch.from_numpy(logits_indices).to( # logits_indices = torch.from_numpy(logits_indices).to(
self.device, non_blocking=True # self.device, non_blocking=True
) # )
target_logits_indices = torch.from_numpy(target_logits_indices).to( # target_logits_indices = torch.from_numpy(target_logits_indices).to(
self.device, non_blocking=True # self.device, non_blocking=True
) # )
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( # bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
self.device, non_blocking=True # self.device, non_blocking=True
) # )
# # Compute the draft token ids.
# # draft_token_indices: [ 1, 2, 3, 105, 106, 208]
# draft_token_ids = self.input_ids.gpu[logits_indices]
# draft_token_ids = draft_token_ids[target_logits_indices + 1]
# Optimize the H2D in the process of creating spec decode metadata
fused_meta_data = cu_num_draft_tokens.tolist() + cu_num_sampled_tokens.tolist()\
+ logits_indices.tolist() + target_logits_indices.tolist() + bonus_logits_indices.tolist()\
+ draft_token_indices.tolist()
fused_meta_data_len = np.array([len(cu_num_draft_tokens), len(cu_num_sampled_tokens),\
len(logits_indices), len(target_logits_indices),\
len(bonus_logits_indices), len(draft_token_indices)], dtype=np.int32)
cu_fused_meta_data_len = np.cumsum(fused_meta_data_len, dtype=np.int32)
fused_meta_data = torch.tensor(
fused_meta_data, dtype=torch.int32, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
cu_num_draft_tokens = fused_meta_data[:cu_fused_meta_data_len[0]]
cu_num_sampled_tokens = fused_meta_data[cu_fused_meta_data_len[0]:cu_fused_meta_data_len[1]]
logits_indices = fused_meta_data[cu_fused_meta_data_len[1]:cu_fused_meta_data_len[2]]
target_logits_indices = fused_meta_data[cu_fused_meta_data_len[2]:cu_fused_meta_data_len[3]]
bonus_logits_indices = fused_meta_data[cu_fused_meta_data_len[3]:cu_fused_meta_data_len[4]]
draft_token_indices = fused_meta_data[cu_fused_meta_data_len[4]:cu_fused_meta_data_len[5]]
# Compute the draft token ids. # Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208] # draft_token_indices: [ 1, 2, 3, 105, 106, 208]
draft_token_ids = self.input_ids.gpu[logits_indices] draft_token_ids = self.input_ids.gpu[logits_indices]
draft_token_ids = draft_token_ids[target_logits_indices + 1] draft_token_ids = draft_token_ids[draft_token_indices]
return SpecDecodeMetadata( return SpecDecodeMetadata(
draft_token_ids=draft_token_ids, draft_token_ids=draft_token_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment