Unverified Commit 7b9a174a authored by Shangming Cai, committed by GitHub

[PD][Spec] Fix hidden state transfer for spec decode (#7516)

Moves output_hidden_states to the end of the metadata buffer layout and switches the auxiliary-data transfer to a single batch_transfer_sync call covering every metadata buffer (output ids, logprobs, and hidden states), instead of transferring only the first buffer. Hidden states are populated only for PD (prefill-decode disaggregation) with speculative decoding.
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
parent 03c039c4
@@ -579,11 +579,11 @@ class DecodeTransferQueue:
             idx = decode_req.metadata_buffer_index
             (
                 output_id,
-                output_hidden_states,
                 output_token_logprobs_val,
                 output_token_logprobs_idx,
                 output_top_logprobs_val,
                 output_top_logprobs_idx,
+                output_hidden_states,
             ) = self.metadata_buffers.get_buf(idx)
             decode_req.req.output_ids.append(output_id[0].item())
......
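The reorder above only works because get_buf returns the buffers positionally: every caller that unpacks the tuple has to move output_hidden_states to the end in lockstep with MetadataBuffers. A minimal sketch of that contract (simplified names, not the sglang implementation):

import torch

def get_buf_sketch(buffers: dict, idx: int):
    # Hidden states deliberately last, matching the new layout.
    return (
        buffers["output_ids"][idx],
        buffers["token_logprobs_val"][idx],
        buffers["token_logprobs_idx"][idx],
        buffers["top_logprobs_val"][idx],
        buffers["top_logprobs_idx"][idx],
        buffers["hidden_states"][idx],
    )

buffers = {
    "output_ids": torch.zeros((4, 16), dtype=torch.int32),
    "token_logprobs_val": torch.zeros((4, 16)),
    "token_logprobs_idx": torch.zeros((4, 16), dtype=torch.int32),
    "top_logprobs_val": torch.zeros((4, 128)),
    "top_logprobs_idx": torch.zeros((4, 128), dtype=torch.int32),
    "hidden_states": torch.zeros((4, 4096)),
}

# The consumer must unpack in the same order, hidden states last; unpacking
# in the old order would silently hand back the wrong tensors.
output_id, lp_val, lp_idx, top_val, top_idx, hidden = get_buf_sketch(buffers, 0)
assert hidden.shape == (4096,)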
@@ -291,15 +291,21 @@ class MooncakeKVManager(BaseKVManager):
         dst_aux_ptrs: list[int],
         dst_aux_index: int,
     ):
-        aux_item_len = self.kv_args.aux_item_lens[0]
-        prefill_aux_addr = (
-            self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
-        )
-        decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
-        status = self.engine.transfer_sync(
-            mooncake_session_id, prefill_aux_addr, decode_aux_addr, aux_item_len
-        )
-        return status
+        src_addr_list = []
+        dst_addr_list = []
+        length_list = []
+        prefill_aux_ptrs = self.kv_args.aux_data_ptrs
+        prefill_aux_item_lens = self.kv_args.aux_item_lens
+        for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
+            length = prefill_aux_item_lens[i]
+            src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
+            dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
+            src_addr_list.append(src_addr)
+            dst_addr_list.append(dst_addr)
+            length_list.append(length)
+        return self.engine.batch_transfer_sync(
+            mooncake_session_id, src_addr_list, dst_addr_list, length_list
+        )

     def sync_status_to_decode_endpoint(
         self, remote: str, dst_port: int, room: int, status: int
......
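The old path computed a single source/destination address pair from aux_data_ptrs[0], so only the first metadata buffer ever crossed the wire; the new loop builds one (src, dst, length) triple per buffer and hands the three lists to Mooncake's batch_transfer_sync. A standalone sketch of that address arithmetic, with fabricated pointer values purely for illustration:

def build_aux_transfer_batch(
    src_base_ptrs: list[int],   # one base pointer per metadata buffer
    dst_base_ptrs: list[int],   # matching buffers on the decode side
    item_lens: list[int],       # per-slot size of each buffer, in bytes
    src_index: int,             # slot in the prefill metadata buffers
    dst_index: int,             # slot in the decode metadata buffers
) -> tuple[list[int], list[int], list[int]]:
    src_addrs, dst_addrs, lengths = [], [], []
    for src_base, dst_base, length in zip(src_base_ptrs, dst_base_ptrs, item_lens):
        # Each buffer is a (size, item) array, so slot i lives at
        # base + item_len * i on both sides.
        src_addrs.append(src_base + length * src_index)
        dst_addrs.append(dst_base + length * dst_index)
        lengths.append(length)
    return src_addrs, dst_addrs, lengths

# Example with made-up addresses: two buffers with 64-byte and 8192-byte slots.
src, dst, lens = build_aux_transfer_batch(
    src_base_ptrs=[0x1000, 0x9000],
    dst_base_ptrs=[0x2000, 0xA000],
    item_lens=[64, 8192],
    src_index=3,
    dst_index=1,
)
assert src == [0x1000 + 64 * 3, 0x9000 + 8192 * 3]
assert dst == [0x2000 + 64 * 1, 0xA000 + 8192 * 1]
# These three lists are what gets handed to engine.batch_transfer_sync(...).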
@@ -107,9 +107,6 @@ class MetadataBuffers:
         # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
         self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
-        self.output_hidden_states = torch.zeros(
-            (size, hidden_size), dtype=dtype, device=device
-        )
         self.output_token_logprobs_val = torch.zeros(
             (size, 16), dtype=torch.float32, device=device
         )
@@ -122,51 +119,50 @@ class MetadataBuffers:
         self.output_top_logprobs_idx = torch.zeros(
             (size, max_top_logprobs_num), dtype=torch.int32, device=device
         )
+        self.output_hidden_states = torch.zeros(
+            (size, hidden_size), dtype=dtype, device=device
+        )

     def get_buf_infos(self):
         ptrs = [
             self.output_ids.data_ptr(),
-            self.output_hidden_states.data_ptr(),
+            # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
             self.output_token_logprobs_val.data_ptr(),
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
             self.output_top_logprobs_idx.data_ptr(),
+            self.output_hidden_states.data_ptr(),
         ]
         data_lens = [
             self.output_ids.nbytes,
-            self.output_hidden_states.nbytes,
             self.output_token_logprobs_val.nbytes,
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
             self.output_top_logprobs_idx.nbytes,
+            self.output_hidden_states.nbytes,
         ]
         item_lens = [
             self.output_ids[0].nbytes,
-            self.output_hidden_states[0].nbytes,
             self.output_token_logprobs_val[0].nbytes,
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
             self.output_top_logprobs_idx[0].nbytes,
+            self.output_hidden_states[0].nbytes,
         ]
         return ptrs, data_lens, item_lens

     def get_buf(self, idx: int):
         return (
             self.output_ids[idx],
-            self.output_hidden_states[idx],
             self.output_token_logprobs_val[idx],
             self.output_token_logprobs_idx[idx],
             self.output_top_logprobs_val[idx],
             self.output_top_logprobs_idx[idx],
+            self.output_hidden_states[idx],
         )

     def set_buf(self, req: Req):
         self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
-        if req.hidden_states_tensor is not None:
-            self.output_hidden_states[req.metadata_buffer_index].copy_(
-                req.hidden_states_tensor
-            )
         if req.return_logprob:
             if req.output_token_logprobs_val:  # not none or empty list
                 self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
@@ -189,6 +185,11 @@ class MetadataBuffers:
                 ] = torch.tensor(
                     req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
                 )
+        # for PD + spec decode
+        if req.hidden_states_tensor is not None:
+            self.output_hidden_states[req.metadata_buffer_index].copy_(
+                req.hidden_states_tensor
+            )

 #########################
......
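Allocating and listing output_hidden_states last is what makes the TODO in get_buf_infos cheap to implement later: when no speculative algorithm is configured, the hidden-state entry can simply be dropped from the tail of each list without reshuffling the other buffers' indices. A rough sketch of that idea (assumed names and simplified fields, not the sglang API):

import torch

class MetadataBuffersSketch:
    def __init__(self, size: int, hidden_size: int, dtype=torch.bfloat16):
        # Fixed-size metadata first; the real class pads rows past 64 bytes
        # because that is the minimal practical RDMA transfer size.
        self.output_ids = torch.zeros((size, 16), dtype=torch.int32)
        self.output_token_logprobs_val = torch.zeros((size, 16), dtype=torch.float32)
        # (remaining logprob buffers elided for brevity)
        # Hidden states go last: they are only needed for PD + spec decode.
        self.output_hidden_states = torch.zeros((size, hidden_size), dtype=dtype)

    def get_buf_infos(self, with_hidden_states: bool = True):
        tensors = [
            self.output_ids,
            self.output_token_logprobs_val,
            self.output_hidden_states,
        ]
        if not with_hidden_states:
            # Dropping the tail entry is trivial precisely because the
            # hidden-state buffer is registered last.
            tensors = tensors[:-1]
        ptrs = [t.data_ptr() for t in tensors]
        data_lens = [t.nbytes for t in tensors]
        item_lens = [t[0].nbytes for t in tensors]
        return ptrs, data_lens, item_lens

bufs = MetadataBuffersSketch(size=8, hidden_size=4096)
ptrs, data_lens, item_lens = bufs.get_buf_infos(with_hidden_states=False)
assert len(ptrs) == 2  # hidden states excluded without touching other indices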