"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "d3a416cd99d0427cfbce8e997b058a3fbb78b37e"
Unverified commit 4f838c09, authored by Atream, committed by GitHub

[PD] Transfer hidden states for mtp when disaggregation (#7242)

parent d20a073b
@@ -541,6 +541,7 @@ class DecodeTransferQueue:
         self.metadata_buffers = metadata_buffers
         self.scheduler = scheduler
         self.tree_cache = tree_cache
+        self.spec_algorithm = scheduler.spec_algorithm

     def add(self, decode_req: DecodeRequest) -> None:
         self.queue.append(decode_req)
@@ -582,6 +583,7 @@ class DecodeTransferQueue:
                 idx = decode_req.metadata_buffer_index
                 (
                     output_id,
+                    output_hidden_states,
                     output_token_logprobs_val,
                     output_token_logprobs_idx,
                     output_top_logprobs_val,
@@ -589,7 +591,8 @@ class DecodeTransferQueue:
                 ) = self.metadata_buffers.get_buf(idx)
                 decode_req.req.output_ids.append(output_id[0].item())
+                if not self.spec_algorithm.is_none():
+                    decode_req.req.hidden_states_tensor = output_hidden_states
                 if decode_req.req.return_logprob:
                     decode_req.req.output_token_logprobs_val.append(
                         output_token_logprobs_val[0].item()
...
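Since `get_buf` returns a positional tuple, inserting `output_hidden_states` right after `output_id` means every unpacking site must grow by one slot in the same position, as above. A toy, self-contained sketch of that unpacking pattern (the tensor shapes and values are illustrative, not taken from the buffers in this diff):

```python
import torch

# Toy stand-in for one metadata slot, in the same field order as get_buf:
# output_id first, then the new hidden-states slot, then a logprob field.
slot = (
    torch.tensor([42]),   # output_id
    torch.zeros(8),       # output_hidden_states (newly inserted slot)
    torch.zeros(16),      # output_token_logprobs_val
)

output_id, output_hidden_states, output_token_logprobs_val = slot
print(output_id[0].item(), tuple(output_hidden_states.shape))
```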
@@ -126,15 +126,16 @@ class ScheduleBatchDisaggregationDecodeMixin:
        )
        topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
+       hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
+       hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)

        # local import to avoid circular import
        from sglang.srt.speculative.eagle_utils import EagleDraftInput

        spec_info = EagleDraftInput(
            topk_p=topk_p,
            topk_index=topk_index,
-           hidden_states=torch.ones(
-               (b, model_config.hidden_size), device=self.device
-           ),
+           hidden_states=hidden_states,
            verified_id=self.output_ids,
        )
        spec_info.prepare_for_extend(self)
...
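The decode-side extend batch previously seeded `EagleDraftInput` with a `torch.ones` placeholder; it now stacks the hidden state transferred for each request into a real `(batch, hidden_size)` tensor. A minimal sketch of that stacking step with made-up sizes:

```python
import torch

hidden_size = 8  # illustrative; real models are much larger
per_req_hidden = [torch.randn(hidden_size) for _ in range(3)]  # one vector per request

# Stack the per-request vectors into a (batch, hidden_size) tensor, as done with
# req.hidden_states_tensor above, then move it to the batch's device (CPU here).
hidden_states = torch.stack(per_req_hidden, dim=0).to("cpu")
print(hidden_states.shape)  # torch.Size([3, 8])
```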
@@ -393,6 +393,8 @@ class SchedulerDisaggregationPrefillMixin:
            logits_output.input_token_logprobs = tuple(
                logits_output.input_token_logprobs.tolist()
            )
+
+       hidden_state_offset = 0
        for i, (req, next_token_id) in enumerate(
            zip(batch.reqs, next_token_ids, strict=True)
        ):
@@ -402,6 +404,16 @@ class SchedulerDisaggregationPrefillMixin:
            req.output_ids.append(next_token_id)
            self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
            self.disagg_prefill_inflight_queue.append(req)
+           if logits_output.hidden_states is not None:
+               last_hidden_index = (
+                   hidden_state_offset + extend_input_len_per_req[i] - 1
+               )
+               req.hidden_states_tensor = (
+                   logits_output.hidden_states[last_hidden_index].cpu().clone()
+               )
+               hidden_state_offset += extend_input_len_per_req[i]
+           else:
+               req.hidden_states_tensor = None
            if req.return_logprob:
                assert extend_logprob_start_len_per_req is not None
                assert extend_input_len_per_req is not None
...
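`logits_output.hidden_states` holds the hidden states of all extended tokens in the batch packed back to back, so the last hidden state of request `i` sits at the running offset plus that request's extend length minus one, as the hunk above computes. A small sketch of the offset arithmetic with made-up lengths:

```python
import torch

extend_input_len_per_req = [4, 2, 3]                  # illustrative per-request extend lengths
packed = torch.arange(sum(extend_input_len_per_req))  # stand-in for the packed hidden states

hidden_state_offset = 0
last_indices = []
for length in extend_input_len_per_req:
    last_indices.append(hidden_state_offset + length - 1)  # final token of this request
    hidden_state_offset += length

print(last_indices)                   # [3, 5, 8]
print(packed[last_indices].tolist())  # the entries that would be copied per request
```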
@@ -88,6 +88,8 @@ class MetadataBuffers:
    def __init__(
        self,
        size: int,
+       hidden_size: int,
+       dtype: torch.dtype,
        max_top_logprobs_num: int = 128,
        custom_mem_pool: torch.cuda.MemPool = None,
    ):
@@ -104,6 +106,10 @@ class MetadataBuffers:
            # We transfer the metadata of first output token to decode
            # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
            self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
+
+           self.output_hidden_states = torch.zeros(
+               (size, hidden_size), dtype=dtype, device=device
+           )
            self.output_token_logprobs_val = torch.zeros(
                (size, 16), dtype=torch.float32, device=device
            )
@@ -120,6 +126,7 @@ class MetadataBuffers:
    def get_buf_infos(self):
        ptrs = [
            self.output_ids.data_ptr(),
+           self.output_hidden_states.data_ptr(),  # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
            self.output_token_logprobs_val.data_ptr(),
            self.output_token_logprobs_idx.data_ptr(),
            self.output_top_logprobs_val.data_ptr(),
@@ -127,6 +134,7 @@ class MetadataBuffers:
        ]
        data_lens = [
            self.output_ids.nbytes,
+           self.output_hidden_states.nbytes,
            self.output_token_logprobs_val.nbytes,
            self.output_token_logprobs_idx.nbytes,
            self.output_top_logprobs_val.nbytes,
@@ -134,6 +142,7 @@ class MetadataBuffers:
        ]
        item_lens = [
            self.output_ids[0].nbytes,
+           self.output_hidden_states[0].nbytes,
            self.output_token_logprobs_val[0].nbytes,
            self.output_token_logprobs_idx[0].nbytes,
            self.output_top_logprobs_val[0].nbytes,
@@ -144,6 +153,7 @@ class MetadataBuffers:
    def get_buf(self, idx: int):
        return (
            self.output_ids[idx],
+           self.output_hidden_states[idx],
            self.output_token_logprobs_val[idx],
            self.output_token_logprobs_idx[idx],
            self.output_top_logprobs_val[idx],
@@ -153,6 +163,10 @@ class MetadataBuffers:
    def set_buf(self, req: Req):
        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+       if req.hidden_states_tensor is not None:
+           self.output_hidden_states[req.metadata_buffer_index].copy_(
+               req.hidden_states_tensor
+           )
        if req.return_logprob:
            if req.output_token_logprobs_val:  # not none or empty list
                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
...
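The new `output_hidden_states` buffer adds `hidden_size * dtype.itemsize` bytes to every metadata slot registered with the transfer engine. A quick back-of-the-envelope check with illustrative numbers (a 7168-wide bfloat16 hidden state costs about 14 KiB per slot):

```python
import torch

size, hidden_size, dtype = 4, 7168, torch.bfloat16   # illustrative buffer parameters
buf = torch.zeros((size, hidden_size), dtype=dtype)  # CPU stand-in for the device buffer

print(buf[0].nbytes)  # per-slot item_len: 7168 * 2 = 14336 bytes (~14 KiB)
print(buf.nbytes)     # total data_len registered for the transfer
```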
@@ -584,6 +584,7 @@ class Req:
            self.output_token_ids_logprobs_idx
        ) = None
        self.hidden_states: List[List[float]] = []
+       self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP

        # Embedding (return values)
        self.embedding = None
...
@@ -627,6 +627,8 @@ class Scheduler(
            )
            self.disagg_metadata_buffers = MetadataBuffers(
                buffer_size,
+               hidden_size=self.model_config.hf_text_config.hidden_size,
+               dtype=self.model_config.dtype,
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )
@@ -677,6 +679,8 @@ class Scheduler(
            )
            self.disagg_metadata_buffers = MetadataBuffers(
                buffer_size,
+               hidden_size=self.model_config.hf_text_config.hidden_size,
+               dtype=self.model_config.dtype,
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )
@@ -1681,13 +1685,15 @@ class Scheduler(
        # These 2 values are needed for processing the output, but the values can be
        # modified by overlap schedule. So we have to copy them here so that
        # we can use the correct values in output processing.
-       if batch.return_logprob:
+       if batch.return_logprob or self.spec_algorithm.is_eagle():
            extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
+       else:
+           extend_input_len_per_req = None
+       if batch.return_logprob:
            extend_logprob_start_len_per_req = [
                req.extend_logprob_start_len for req in batch.reqs
            ]
        else:
-           extend_input_len_per_req = None
            extend_logprob_start_len_per_req = None

        ret = GenerationBatchResult(
...
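`extend_input_len_per_req` used to be captured only for logprob requests; the prefill path now also needs it whenever EAGLE-style speculation is enabled, because it indexes into the packed hidden states per request, so the two conditions are tracked separately. A condensed sketch of the resulting control flow with illustrative flags:

```python
return_logprob, is_eagle = False, True  # illustrative flags
extend_lens = [4, 2, 3]                 # stand-in for [req.extend_input_len for req in batch.reqs]

# Needed for logprob post-processing *or* for locating each request's last hidden state.
extend_input_len_per_req = extend_lens if (return_logprob or is_eagle) else None

# Only needed when logprobs were actually requested.
extend_logprob_start_len_per_req = [0, 0, 0] if return_logprob else None

print(extend_input_len_per_req, extend_logprob_start_len_per_req)
```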