"tests/pytorch/vscode:/vscode.git/clone" did not exist on "7c059e86c736f19b63d14df789e546ce479dfdfa"
Unverified Commit 4f838c09 authored by Atream's avatar Atream Committed by GitHub
Browse files

[PD] Transfer hidden states for mtp when disaggregation (#7242)

parent d20a073b
......@@ -541,6 +541,7 @@ class DecodeTransferQueue:
self.metadata_buffers = metadata_buffers
self.scheduler = scheduler
self.tree_cache = tree_cache
self.spec_algorithm = scheduler.spec_algorithm
def add(self, decode_req: DecodeRequest) -> None:
self.queue.append(decode_req)
......@@ -582,6 +583,7 @@ class DecodeTransferQueue:
idx = decode_req.metadata_buffer_index
(
output_id,
output_hidden_states,
output_token_logprobs_val,
output_token_logprobs_idx,
output_top_logprobs_val,
......@@ -589,7 +591,8 @@ class DecodeTransferQueue:
) = self.metadata_buffers.get_buf(idx)
decode_req.req.output_ids.append(output_id[0].item())
if not self.spec_algorithm.is_none():
decode_req.req.hidden_states_tensor = output_hidden_states
if decode_req.req.return_logprob:
decode_req.req.output_token_logprobs_val.append(
output_token_logprobs_val[0].item()
......
......@@ -126,15 +126,16 @@ class ScheduleBatchDisaggregationDecodeMixin:
)
topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
# local import to avoid circular import
from sglang.srt.speculative.eagle_utils import EagleDraftInput
spec_info = EagleDraftInput(
topk_p=topk_p,
topk_index=topk_index,
hidden_states=torch.ones(
(b, model_config.hidden_size), device=self.device
),
hidden_states=hidden_states,
verified_id=self.output_ids,
)
spec_info.prepare_for_extend(self)
......
......@@ -393,6 +393,8 @@ class SchedulerDisaggregationPrefillMixin:
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
hidden_state_offset = 0
for i, (req, next_token_id) in enumerate(
zip(batch.reqs, next_token_ids, strict=True)
):
......@@ -402,6 +404,16 @@ class SchedulerDisaggregationPrefillMixin:
req.output_ids.append(next_token_id)
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
self.disagg_prefill_inflight_queue.append(req)
if logits_output.hidden_states is not None:
last_hidden_index = (
hidden_state_offset + extend_input_len_per_req[i] - 1
)
req.hidden_states_tensor = (
logits_output.hidden_states[last_hidden_index].cpu().clone()
)
hidden_state_offset += extend_input_len_per_req[i]
else:
req.hidden_states_tensor = None
if req.return_logprob:
assert extend_logprob_start_len_per_req is not None
assert extend_input_len_per_req is not None
......
......@@ -88,6 +88,8 @@ class MetadataBuffers:
def __init__(
self,
size: int,
hidden_size: int,
dtype: torch.dtype,
max_top_logprobs_num: int = 128,
custom_mem_pool: torch.cuda.MemPool = None,
):
......@@ -104,6 +106,10 @@ class MetadataBuffers:
# We transfer the metadata of first output token to decode
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
self.output_hidden_states = torch.zeros(
(size, hidden_size), dtype=dtype, device=device
)
self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device=device
)
......@@ -120,6 +126,7 @@ class MetadataBuffers:
def get_buf_infos(self):
ptrs = [
self.output_ids.data_ptr(),
self.output_hidden_states.data_ptr(), # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
self.output_token_logprobs_val.data_ptr(),
self.output_token_logprobs_idx.data_ptr(),
self.output_top_logprobs_val.data_ptr(),
......@@ -127,6 +134,7 @@ class MetadataBuffers:
]
data_lens = [
self.output_ids.nbytes,
self.output_hidden_states.nbytes,
self.output_token_logprobs_val.nbytes,
self.output_token_logprobs_idx.nbytes,
self.output_top_logprobs_val.nbytes,
......@@ -134,6 +142,7 @@ class MetadataBuffers:
]
item_lens = [
self.output_ids[0].nbytes,
self.output_hidden_states[0].nbytes,
self.output_token_logprobs_val[0].nbytes,
self.output_token_logprobs_idx[0].nbytes,
self.output_top_logprobs_val[0].nbytes,
......@@ -144,6 +153,7 @@ class MetadataBuffers:
def get_buf(self, idx: int):
return (
self.output_ids[idx],
self.output_hidden_states[idx],
self.output_token_logprobs_val[idx],
self.output_token_logprobs_idx[idx],
self.output_top_logprobs_val[idx],
......@@ -153,6 +163,10 @@ class MetadataBuffers:
def set_buf(self, req: Req):
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
if req.hidden_states_tensor is not None:
self.output_hidden_states[req.metadata_buffer_index].copy_(
req.hidden_states_tensor
)
if req.return_logprob:
if req.output_token_logprobs_val: # not none or empty list
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
......
......@@ -584,6 +584,7 @@ class Req:
self.output_token_ids_logprobs_idx
) = None
self.hidden_states: List[List[float]] = []
self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP
# Embedding (return values)
self.embedding = None
......
......@@ -627,6 +627,8 @@ class Scheduler(
)
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
hidden_size=self.model_config.hf_text_config.hidden_size,
dtype=self.model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)
......@@ -677,6 +679,8 @@ class Scheduler(
)
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
hidden_size=self.model_config.hf_text_config.hidden_size,
dtype=self.model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)
......@@ -1681,13 +1685,15 @@ class Scheduler(
# These 2 values are needed for processing the output, but the values can be
# modified by overlap schedule. So we have to copy them here so that
# we can use the correct values in output processing.
if batch.return_logprob:
if batch.return_logprob or self.spec_algorithm.is_eagle():
extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
else:
extend_input_len_per_req = None
if batch.return_logprob:
extend_logprob_start_len_per_req = [
req.extend_logprob_start_len for req in batch.reqs
]
else:
extend_input_len_per_req = None
extend_logprob_start_len_per_req = None
ret = GenerationBatchResult(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment