Unverified Commit 7b9a174a authored by Shangming Cai's avatar Shangming Cai Committed by GitHub
Browse files

[PD][Spec] Fix hidden state transfer for spec decode (#7516)


Signed-off-by: default avatarShangming Cai <caishangming@linux.alibaba.com>
parent 03c039c4
......@@ -579,11 +579,11 @@ class DecodeTransferQueue:
idx = decode_req.metadata_buffer_index
(
output_id,
output_hidden_states,
output_token_logprobs_val,
output_token_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
output_hidden_states,
) = self.metadata_buffers.get_buf(idx)
decode_req.req.output_ids.append(output_id[0].item())
......
......@@ -291,15 +291,21 @@ class MooncakeKVManager(BaseKVManager):
dst_aux_ptrs: list[int],
dst_aux_index: int,
):
aux_item_len = self.kv_args.aux_item_lens[0]
prefill_aux_addr = (
self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
src_addr_list = []
dst_addr_list = []
length_list = []
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
prefill_aux_item_lens = self.kv_args.aux_item_lens
for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
length = prefill_aux_item_lens[i]
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
src_addr_list.append(src_addr)
dst_addr_list.append(dst_addr)
length_list.append(length)
return self.engine.batch_transfer_sync(
mooncake_session_id, src_addr_list, dst_addr_list, length_list
)
decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
status = self.engine.transfer_sync(
mooncake_session_id, prefill_aux_addr, decode_aux_addr, aux_item_len
)
return status
def sync_status_to_decode_endpoint(
self, remote: str, dst_port: int, room: int, status: int
......
......@@ -107,9 +107,6 @@ class MetadataBuffers:
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
self.output_hidden_states = torch.zeros(
(size, hidden_size), dtype=dtype, device=device
)
self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device=device
)
......@@ -122,51 +119,50 @@ class MetadataBuffers:
self.output_top_logprobs_idx = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.int32, device=device
)
self.output_hidden_states = torch.zeros(
(size, hidden_size), dtype=dtype, device=device
)
def get_buf_infos(self):
ptrs = [
self.output_ids.data_ptr(),
self.output_hidden_states.data_ptr(), # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
self.output_token_logprobs_val.data_ptr(),
self.output_token_logprobs_idx.data_ptr(),
self.output_top_logprobs_val.data_ptr(),
self.output_top_logprobs_idx.data_ptr(),
self.output_hidden_states.data_ptr(),
]
data_lens = [
self.output_ids.nbytes,
self.output_hidden_states.nbytes,
self.output_token_logprobs_val.nbytes,
self.output_token_logprobs_idx.nbytes,
self.output_top_logprobs_val.nbytes,
self.output_top_logprobs_idx.nbytes,
self.output_hidden_states.nbytes,
]
item_lens = [
self.output_ids[0].nbytes,
self.output_hidden_states[0].nbytes,
self.output_token_logprobs_val[0].nbytes,
self.output_token_logprobs_idx[0].nbytes,
self.output_top_logprobs_val[0].nbytes,
self.output_top_logprobs_idx[0].nbytes,
self.output_hidden_states[0].nbytes,
]
return ptrs, data_lens, item_lens
def get_buf(self, idx: int):
return (
self.output_ids[idx],
self.output_hidden_states[idx],
self.output_token_logprobs_val[idx],
self.output_token_logprobs_idx[idx],
self.output_top_logprobs_val[idx],
self.output_top_logprobs_idx[idx],
self.output_hidden_states[idx],
)
def set_buf(self, req: Req):
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
if req.hidden_states_tensor is not None:
self.output_hidden_states[req.metadata_buffer_index].copy_(
req.hidden_states_tensor
)
if req.return_logprob:
if req.output_token_logprobs_val: # not none or empty list
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
......@@ -189,6 +185,11 @@ class MetadataBuffers:
] = torch.tensor(
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
)
# for PD + spec decode
if req.hidden_states_tensor is not None:
self.output_hidden_states[req.metadata_buffer_index].copy_(
req.hidden_states_tensor
)
#########################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment