Unverified commit e23e280e, authored by Shangming Cai, committed by GitHub
Browse files

Add support for topk metadata transferring for PD (#10616)


Signed-off-by: Shangming Cai <csmthu@gmail.com>
parent 51f7c6bd
...@@ -614,12 +614,16 @@ class DecodeTransferQueue: ...@@ -614,12 +614,16 @@ class DecodeTransferQueue:
output_token_logprobs_idx, output_token_logprobs_idx,
output_top_logprobs_val, output_top_logprobs_val,
output_top_logprobs_idx, output_top_logprobs_idx,
output_topk_p,
output_topk_index,
output_hidden_states, output_hidden_states,
) = self.metadata_buffers.get_buf(idx) ) = self.metadata_buffers.get_buf(idx)
decode_req.req.output_ids.append(output_id[0].item()) decode_req.req.output_ids.append(output_id[0].item())
decode_req.req.cached_tokens = cached_tokens[0].item() decode_req.req.cached_tokens = cached_tokens[0].item()
if not self.spec_algorithm.is_none(): if not self.spec_algorithm.is_none():
decode_req.req.output_topk_p = output_topk_p
decode_req.req.output_topk_index = output_topk_index
decode_req.req.hidden_states_tensor = output_hidden_states decode_req.req.hidden_states_tensor = output_hidden_states
if decode_req.req.return_logprob: if decode_req.req.return_logprob:
decode_req.req.output_token_logprobs_val.append( decode_req.req.output_token_logprobs_val.append(
......
...@@ -125,25 +125,33 @@ class ScheduleBatchDisaggregationDecodeMixin: ...@@ -125,25 +125,33 @@ class ScheduleBatchDisaggregationDecodeMixin:
req.grammar.finished = req.finished() req.grammar.finished = req.finished()
self.output_ids = torch.tensor(self.output_ids, device=self.device) self.output_ids = torch.tensor(self.output_ids, device=self.device)
# Simulate the eagle run. We add mock data to hidden states for the # Simulate the eagle run.
# ease of implementation now meaning the first token will have acc rate if self.spec_algorithm.is_eagle():
# of 0.
if not self.spec_algorithm.is_none():
b = len(self.reqs) b = len(self.reqs)
topk_p = torch.arange( topk = server_args.speculative_eagle_topk
b * server_args.speculative_eagle_topk, topk_p = torch.stack(
0, [
-1, torch.as_tensor(
req.output_topk_p[:topk],
device=self.device, device=self.device,
dtype=torch.float32, dtype=torch.float32,
) )
topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk) for req in self.reqs
topk_p /= b * server_args.speculative_eagle_topk ],
topk_index = torch.arange( dim=0,
b * server_args.speculative_eagle_topk, device=self.device )
topk_index = torch.stack(
[
torch.as_tensor(
req.output_topk_index[:topk],
device=self.device,
dtype=torch.int64,
)
for req in self.reqs
],
dim=0,
) )
topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
hidden_states_list = [req.hidden_states_tensor for req in self.reqs] hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device) hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
......
...@@ -421,6 +421,8 @@ class SchedulerDisaggregationPrefillMixin: ...@@ -421,6 +421,8 @@ class SchedulerDisaggregationPrefillMixin:
last_hidden_index = ( last_hidden_index = (
hidden_state_offset + extend_input_len_per_req[i] - 1 hidden_state_offset + extend_input_len_per_req[i] - 1
) )
req.output_topk_p = batch.spec_info.topk_p[i]
req.output_topk_index = batch.spec_info.topk_index[i]
if self.spec_algorithm.is_eagle3(): if self.spec_algorithm.is_eagle3():
req.hidden_states_tensor = ( req.hidden_states_tensor = (
batch.spec_info.hidden_states[i].cpu().clone() batch.spec_info.hidden_states[i].cpu().clone()
......
...@@ -85,7 +85,7 @@ class MetadataBuffers: ...@@ -85,7 +85,7 @@ class MetadataBuffers:
self, self,
size: int, size: int,
hidden_size: int, hidden_size: int,
dtype: torch.dtype, hidden_states_dtype: torch.dtype,
max_top_logprobs_num: int = 128, max_top_logprobs_num: int = 128,
custom_mem_pool: torch.cuda.MemPool = None, custom_mem_pool: torch.cuda.MemPool = None,
): ):
...@@ -122,8 +122,15 @@ class MetadataBuffers: ...@@ -122,8 +122,15 @@ class MetadataBuffers:
self.output_top_logprobs_idx = torch.zeros( self.output_top_logprobs_idx = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.int32, device=device (size, max_top_logprobs_num), dtype=torch.int32, device=device
) )
# For PD + spec decode
self.output_topk_p = torch.zeros(
(size, 16), dtype=torch.float32, device=device
)
self.output_topk_index = torch.zeros(
(size, 16), dtype=torch.int64, device=device
)
self.output_hidden_states = torch.zeros( self.output_hidden_states = torch.zeros(
(size, hidden_size), dtype=dtype, device=device (size, hidden_size), dtype=hidden_states_dtype, device=device
) )
def get_buf_infos(self): def get_buf_infos(self):
...@@ -134,6 +141,8 @@ class MetadataBuffers: ...@@ -134,6 +141,8 @@ class MetadataBuffers:
self.output_token_logprobs_idx.data_ptr(), self.output_token_logprobs_idx.data_ptr(),
self.output_top_logprobs_val.data_ptr(), self.output_top_logprobs_val.data_ptr(),
self.output_top_logprobs_idx.data_ptr(), self.output_top_logprobs_idx.data_ptr(),
self.output_topk_p.data_ptr(),
self.output_topk_index.data_ptr(),
self.output_hidden_states.data_ptr(), self.output_hidden_states.data_ptr(),
] ]
data_lens = [ data_lens = [
...@@ -143,6 +152,8 @@ class MetadataBuffers: ...@@ -143,6 +152,8 @@ class MetadataBuffers:
self.output_token_logprobs_idx.nbytes, self.output_token_logprobs_idx.nbytes,
self.output_top_logprobs_val.nbytes, self.output_top_logprobs_val.nbytes,
self.output_top_logprobs_idx.nbytes, self.output_top_logprobs_idx.nbytes,
self.output_topk_p.nbytes,
self.output_topk_index.nbytes,
self.output_hidden_states.nbytes, self.output_hidden_states.nbytes,
] ]
item_lens = [ item_lens = [
...@@ -152,6 +163,8 @@ class MetadataBuffers: ...@@ -152,6 +163,8 @@ class MetadataBuffers:
self.output_token_logprobs_idx[0].nbytes, self.output_token_logprobs_idx[0].nbytes,
self.output_top_logprobs_val[0].nbytes, self.output_top_logprobs_val[0].nbytes,
self.output_top_logprobs_idx[0].nbytes, self.output_top_logprobs_idx[0].nbytes,
self.output_topk_p[0].nbytes,
self.output_topk_index[0].nbytes,
self.output_hidden_states[0].nbytes, self.output_hidden_states[0].nbytes,
] ]
return ptrs, data_lens, item_lens return ptrs, data_lens, item_lens
...@@ -164,6 +177,8 @@ class MetadataBuffers: ...@@ -164,6 +177,8 @@ class MetadataBuffers:
self.output_token_logprobs_idx[idx], self.output_token_logprobs_idx[idx],
self.output_top_logprobs_val[idx], self.output_top_logprobs_val[idx],
self.output_top_logprobs_idx[idx], self.output_top_logprobs_idx[idx],
self.output_topk_p[idx],
self.output_topk_index[idx],
self.output_hidden_states[idx], self.output_hidden_states[idx],
) )
...@@ -193,8 +208,17 @@ class MetadataBuffers: ...@@ -193,8 +208,17 @@ class MetadataBuffers:
] = torch.tensor( ] = torch.tensor(
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu" req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
) )
# for PD + spec decode # For PD + spec decode
if req.hidden_states_tensor is not None: if req.hidden_states_tensor is not None:
# speculative_eagle_topk should not be greater than 16 currently
topk = req.output_topk_p.size(0)
self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
req.output_topk_p
)
self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
req.output_topk_index
)
self.output_hidden_states[req.metadata_buffer_index].copy_( self.output_hidden_states[req.metadata_buffer_index].copy_(
req.hidden_states_tensor req.hidden_states_tensor
) )
......
...@@ -607,6 +607,8 @@ class Req: ...@@ -607,6 +607,8 @@ class Req:
) = None ) = None
self.hidden_states: List[List[float]] = [] self.hidden_states: List[List[float]] = []
self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP
self.output_topk_p = None
self.output_topk_index = None
# Embedding (return values) # Embedding (return values)
self.embedding = None self.embedding = None
......
...@@ -806,7 +806,7 @@ class Scheduler( ...@@ -806,7 +806,7 @@ class Scheduler(
self.disagg_metadata_buffers = MetadataBuffers( self.disagg_metadata_buffers = MetadataBuffers(
buffer_size, buffer_size,
hidden_size=self.model_config.hf_text_config.hidden_size, hidden_size=self.model_config.hf_text_config.hidden_size,
dtype=self.model_config.dtype, hidden_states_dtype=self.model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(), custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
) )
...@@ -855,7 +855,7 @@ class Scheduler( ...@@ -855,7 +855,7 @@ class Scheduler(
self.disagg_metadata_buffers = MetadataBuffers( self.disagg_metadata_buffers = MetadataBuffers(
buffer_size, buffer_size,
hidden_size=self.model_config.hf_text_config.hidden_size, hidden_size=self.model_config.hf_text_config.hidden_size,
dtype=self.model_config.dtype, hidden_states_dtype=self.model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(), custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment