Commit 587c919c authored by 王敏
Browse files

[Feat]初步实现PP+MTP

parent ca9ce18d
......@@ -1081,11 +1081,12 @@ class FusedMoEModularKernel(torch.nn.Module):
The _prepare method is a wrapper around self.prepare_finalize.prepare
that handles DBO and async.
"""
expected_m = (
hidden_states.shape[0] * self.fused_experts.num_dispatchers * topk_ids.shape[1]
+ global_num_experts
) // global_num_experts
self.fused_experts.set_expected_m(expected_m)
if self.fused_experts.num_dispatchers is not None:
expected_m = (
hidden_states.shape[0] * self.fused_experts.num_dispatchers * topk_ids.shape[1]
+ global_num_experts
) // global_num_experts
self.fused_experts.set_expected_m(expected_m)
if not self.prepare_finalize.supports_async():
# We shouldn't be running an a2a kernel that doesn't
......
......@@ -341,6 +341,16 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP):
model_has_indexer = any("indexer" in param_name for param_name in params_dict.keys())
for name, loaded_weight in weights:
if "embed_tokens" in name:
for local_name in params_dict.keys():
if "embed_tokens" in local_name:
param = params_dict[local_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
break
if "rotary_emb.inv_freq" in name:
continue
......
......@@ -355,14 +355,14 @@ class Scheduler(SchedulerInterface):
# do not schedule another step for the same request while it still has
# output placeholders for PP.
# TODO: support PP + async scheduling without this limit
if self.use_pp:
if (envs.VLLM_USE_PP_BALANCE and
len(scheduled_new_reqs) + len(scheduled_resumed_reqs)
+ len(scheduled_running_reqs) >= max_batch_running):
break
if request.num_output_placeholders > 0:
req_index += 1
continue
# if self.use_pp:
# if (envs.VLLM_USE_PP_BALANCE and
# len(scheduled_new_reqs) + len(scheduled_resumed_reqs)
# + len(scheduled_running_reqs) >= max_batch_running):
# break
# if request.num_output_placeholders > 0:
# req_index += 1
# continue
if (
request.num_output_placeholders > 0
......@@ -1211,9 +1211,9 @@ class Scheduler(SchedulerInterface):
# do not schedule another step for the same request while it still has
# output placeholders for PP.
# TODO: support PP + async scheduling without this limit
if self.use_pp and request.num_output_placeholders > 0:
req_index += 1
continue
# if self.use_pp and request.num_output_placeholders > 0:
# req_index += 1
# continue
if (
request.num_output_placeholders > 0
......@@ -1617,7 +1617,11 @@ class Scheduler(SchedulerInterface):
for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
req_id = req.request_id
req_ids.append(req_id)
if self.use_pp:
#if self.use_pp:
# NOTE: In PP+async scheduling, we consume token ids via a direct GPU
# broadcast path (`input_batch.prev_sampled_token_ids`), so we can
# omit this payload.
if self.use_pp and not self.scheduler_config.async_scheduling:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't
......
......@@ -966,12 +966,15 @@ class GPUModelRunner(
# that case we include the resumed_req_ids in the unscheduled set so
# that they get cleared from the persistent batch before being re-scheduled
# in the normal resumed request path.
#print(f"##################cached_req_ids:{cached_req_ids} scheduled_req_ids:{scheduled_req_ids} async scheduling:{self.use_async_scheduling}")
unscheduled_req_ids = cached_req_ids - (scheduled_req_ids - resumed_req_ids)
# NOTE(woosuk): The persistent batch optimization assumes that
# consecutive batches contain mostly the same requests. If batches
# have low request overlap (e.g., alternating between two distinct
# sets of requests), this optimization becomes very inefficient.
for req_id in unscheduled_req_ids:
#print(f"#############################remove_request:{req_id}")
self.input_batch.remove_request(req_id)
reqs_to_add: list[CachedRequestState] = []
......@@ -1077,25 +1080,32 @@ class GPUModelRunner(
num_rejected = req_state.prev_num_draft_len - num_accepted
num_computed_tokens -= num_rejected
req_state.output_token_ids.extend([-1] * num_accepted)
#print(f"#############################req_id:{req_id} num_accepted:{num_accepted}")
# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker.
new_token_ids = req_data.new_token_ids[i]
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens.
num_new_tokens = (
num_computed_tokens + len(new_token_ids) - req_state.num_tokens
)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:])
if not req_data.new_token_ids:
# Async scheduled PP: Sampled tokens propagated via GPU broadcast.
new_token_ids: list[int] = []
else:
# Non-async scheduling with PP: The scheduler sends
# sampled token ids back because there's no direct communication
# between the first-stage worker and the last-stage worker.
new_token_ids = req_data.new_token_ids[i]
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens.
num_new_tokens = (
num_computed_tokens + len(new_token_ids) - req_state.num_tokens
)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:]
)
elif num_output_tokens < len(req_state.output_token_ids):
# Some output tokens were discarded due to a sync-KV-load
# failure. Align the cached state.
......@@ -1431,6 +1441,8 @@ class GPUModelRunner(
prev_common_req_indices_tensor = torch.tensor(
prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
#print(f"###############sampled_tokens_index_tensor:{sampled_tokens_index_tensor} prev_common_req_indices_tensor:{prev_common_req_indices_tensor} prev_sampled_token_ids:{self.input_batch.prev_sampled_token_ids}")
self.input_ids.gpu.scatter_(
dim=0,
index=sampled_tokens_index_tensor,
......@@ -3938,6 +3950,7 @@ class GPUModelRunner(
scheduler_output, clear_metadata=clear_kv_metadata
) as kv_connector_output,
):
#print(f"####################execute model input_ids:{input_ids.tolist()}")
model_output = self._model_forward(
input_ids=input_ids,
positions=positions,
......@@ -4028,7 +4041,17 @@ class GPUModelRunner(
self.kv_connector_output = None
if self.execute_model_state is None:
# Nothing to do (PP non-final rank case), output isn't used.
# receive sampled token ids from the last PP rank.
if self.use_async_scheduling and get_pp_group().world_size > 1:
if not self.speculative_config:
self._pp_receive_prev_sampled_token_ids_to_input_batch()
else:
self._draft_token_ids = None
self._draft_token_req_ids = None
self.input_batch.prev_sampled_token_ids = None
self._pp_receive_prev_sampled_token_ids_and_valid_sampled_tokens_count()
self._pp_receive_draft_token_ids()
if not kv_connector_output:
return None # type: ignore[return-value]
......@@ -4070,6 +4093,12 @@ class GPUModelRunner(
sampler_output.sampled_token_ids, scheduler_output
)
pp = get_pp_group()
if self.use_async_scheduling and pp.world_size > 1 and pp.is_last_rank and not self.speculative_config:
self._pp_broadcast_prev_sampled_token_ids(
sampler_output.sampled_token_ids
)
self._draft_token_ids = None
self._draft_token_req_ids = None
self.input_batch.prev_sampled_token_ids = None
......@@ -4090,6 +4119,11 @@ class GPUModelRunner(
)
self._copy_draft_token_ids_to_cpu(scheduler_output)
# broadcast draft_token_ids to non-last pp rank
if self.use_async_scheduling and pp.world_size > 1 and pp.is_last_rank:
self._pp_broadcast_draft_token_ids(self._draft_token_ids)
spec_config = self.speculative_config
propose_drafts_after_bookkeeping = False
if spec_config is not None:
......@@ -4207,6 +4241,129 @@ class GPUModelRunner(
return async_output
def _pp_broadcast_prev_sampled_token_ids(
    self, sampled_token_ids: torch.Tensor
) -> None:
    """Broadcast sampled token ids (GPU) from the last PP stage.

    Must only be called on the last pipeline-parallel rank; the other ranks
    take part in the same collective through
    `_pp_receive_prev_sampled_token_ids_to_input_batch`.

    Args:
        sampled_token_ids: Tensor of shape [num_reqs, 1] holding the token
            ids sampled in this step.

    Raises:
        AssertionError: If called on a non-last PP rank, or if
            `sampled_token_ids` does not have shape [num_reqs, 1].
    """
    pp = get_pp_group()
    assert pp.is_last_rank
    # `prev_sampled_token_ids` is expected to have shape [num_reqs, 1].
    assert sampled_token_ids.dim() == 2 and sampled_token_ids.shape[-1] == 1, (
        "PP+async expects sampled_token_ids to have shape [num_reqs, 1]"
    )
    # The receiving ranks allocate an int32 buffer, so the payload must use
    # the same dtype on every rank. This is a no-op when the tensor is
    # already int32.
    sampled_token_ids = sampled_token_ids.to(torch.int32)
    torch.distributed.broadcast(
        sampled_token_ids, src=pp.rank, group=pp.device_group
    )
def _pp_receive_prev_sampled_token_ids_to_input_batch(self) -> None:
    """Receive the sampled token ids broadcast by the last PP stage.

    The received [num_reqs, 1] int32 tensor is stored on the input batch as
    `prev_sampled_token_ids`. A fresh `prev_req_id_to_index` mapping is then
    built from the current batch order, leaving out rows flagged in
    `discard_request_mask`. Every request that is kept and still has cached
    state gets one `-1` placeholder appended to its `output_token_ids`.

    Raises:
        AssertionError: If called on the last PP rank.
    """
    pp = get_pp_group()
    assert not pp.is_last_rank

    batch = self.input_batch
    num_reqs = batch.num_reqs

    # `prev_sampled_token_ids` is expected to have shape [num_reqs, 1].
    received = torch.empty((num_reqs, 1), dtype=torch.int32, device=self.device)
    torch.distributed.broadcast(received, src=pp.last_rank, group=pp.device_group)
    batch.prev_sampled_token_ids = received

    # Build `prev_req_id_to_index` here so `_prepare_input_ids` can map
    # req_id -> previous batch row.
    discarded_rows = set(np.nonzero(self.discard_request_mask.np[:num_reqs])[0])
    row_of_req: dict[str, int] = {}
    for row, req_id in enumerate(batch.req_ids):
        if row in discarded_rows:
            continue
        row_of_req[req_id] = row
        # PP+async scheduling: advance the locally cached output length by
        # one placeholder (-1) token id.
        req_state = self.requests.get(req_id)
        if req_state is not None:
            req_state.output_token_ids.append(-1)
    batch.prev_req_id_to_index = row_of_req
def _pp_broadcast_prev_sampled_token_ids_and_valid_sampled_tokens_count(
    self, sampled_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor
) -> None:
    """Broadcast sampled token ids and valid-token counts from the last PP stage.

    Both inputs are reshaped to a single column and concatenated into one
    [num_reqs, 2] int32 tensor (column 0: sampled token id, column 1: valid
    sampled token count), so a single collective carries both values.

    Args:
        sampled_token_ids: Tensor viewable as [num_reqs, 1].
        valid_sampled_tokens_count: Tensor viewable as [num_reqs, 1].

    Raises:
        AssertionError: If called on a non-last PP rank, or if either
            reshaped tensor is not of shape [num_reqs, 1].
    """
    pp = get_pp_group()
    assert pp.is_last_rank
    sampled_token_ids = sampled_token_ids.view(-1, 1)
    valid_sampled_tokens_count = valid_sampled_tokens_count.view(-1, 1)
    # `prev_sampled_token_ids` is expected to have shape [num_reqs, 1].
    assert sampled_token_ids.dim() == 2 and sampled_token_ids.shape[-1] == 1, (
        "PP+async expects sampled_token_ids to have shape [num_reqs, 1]"
    )
    # Check the count tensor itself; the previous version re-checked
    # `sampled_token_ids` here.
    assert (
        valid_sampled_tokens_count.dim() == 2
        and valid_sampled_tokens_count.shape[-1] == 1
    ), (
        "PP+async expects valid_sampled_tokens_count to have shape [num_reqs, 1]"
    )
    # The receiving ranks allocate a [num_reqs, 2] int32 buffer. `torch.cat`
    # may promote mixed input dtypes, so cast the packed payload explicitly
    # to keep dtype and size identical across ranks.
    data = torch.cat(
        [sampled_token_ids, valid_sampled_tokens_count], dim=-1
    ).to(torch.int32)
    torch.distributed.broadcast(data, src=pp.rank, group=pp.device_group)
def _pp_receive_prev_sampled_token_ids_and_valid_sampled_tokens_count(self) -> None:
    """Receive sampled token ids and valid-token counts from the last PP stage.

    A [num_reqs, 2] int32 tensor is received; column 0 is passed on as the
    sampled token ids and column 1 as the valid sampled token counts to
    `_copy_valid_sampled_token_count`. A fresh `prev_req_id_to_index` mapping
    is then built from the current batch order, leaving out rows flagged in
    `discard_request_mask`. No placeholder tokens are appended to the cached
    request state in this path.

    Raises:
        AssertionError: If called on the last PP rank.
    """
    pp = get_pp_group()
    assert not pp.is_last_rank

    num_reqs = self.input_batch.num_reqs

    # Packed payload: one row per request, [sampled_token_id, valid_count].
    packed = torch.empty((num_reqs, 2), dtype=torch.int32, device=self.device)
    torch.distributed.broadcast(packed, src=pp.last_rank, group=pp.device_group)
    token_ids_col = packed[:, 0]
    valid_count_col = packed[:, 1]
    self._copy_valid_sampled_token_count(token_ids_col, valid_count_col)

    # Build `prev_req_id_to_index` here so `_prepare_input_ids` can map
    # req_id -> previous batch row.
    discarded_rows = set(np.nonzero(self.discard_request_mask.np[:num_reqs])[0])
    row_of_req: dict[str, int] = {}
    for row, req_id in enumerate(self.input_batch.req_ids):
        if row not in discarded_rows:
            row_of_req[req_id] = row
    self.input_batch.prev_req_id_to_index = row_of_req
def _pp_broadcast_draft_token_ids(
    self, draft_token_ids: torch.Tensor
) -> None:
    """Broadcast draft token ids (GPU) from the last PP stage.

    Must only be called on the last pipeline-parallel rank; the other ranks
    take part in the same collective through `_pp_receive_draft_token_ids`.

    Args:
        draft_token_ids: Tensor of shape [num_reqs, num_spec_tokens]. It is
            cast to int32 before being sent.

    Raises:
        AssertionError: If called on a non-last PP rank, or if
            `draft_token_ids` does not have shape
            [num_reqs, num_spec_tokens].
    """
    pp = get_pp_group()
    assert pp.is_last_rank
    # The receiving ranks allocate an int32 buffer; send the same dtype.
    draft_token_ids = draft_token_ids.to(torch.int32)
    # `draft_token_ids` is expected to have shape [num_reqs, num_spec_tokens].
    assert (
        draft_token_ids.dim() == 2
        and draft_token_ids.shape[-1] == self.num_spec_tokens
    ), (
        "PP+async expects draft_token_ids to have shape [num_reqs, num_spec_tokens]"
    )
    torch.distributed.broadcast(
        draft_token_ids, src=pp.rank, group=pp.device_group
    )
def _pp_receive_draft_token_ids(self) -> None:
    """Receive the draft token ids broadcast by the last PP stage.

    The received [num_reqs, num_spec_tokens] int32 tensor is stored in
    `self._draft_token_ids`.

    Raises:
        AssertionError: If called on the last PP rank.
    """
    pp = get_pp_group()
    assert not pp.is_last_rank

    # `draft_token_ids` is expected to have shape [num_reqs, num_spec_tokens].
    buffer_shape = (self.input_batch.num_reqs, self.num_spec_tokens)
    received = torch.empty(buffer_shape, dtype=torch.int32, device=self.device)
    torch.distributed.broadcast(received, src=pp.last_rank, group=pp.device_group)
    self._draft_token_ids = received
def take_draft_token_ids(self) -> DraftTokenIds | None:
if not self.num_spec_tokens or not self._draft_token_req_ids:
return None
......@@ -4379,6 +4536,12 @@ class GPUModelRunner(
self.discard_request_mask.gpu,
)
)
# broadcast next_token_ids and valid_sampled_tokens_count to non-last pp rank
pp = get_pp_group()
if self.use_async_scheduling and pp.world_size > 1 and pp.is_last_rank:
self._pp_broadcast_prev_sampled_token_ids_and_valid_sampled_tokens_count(next_token_ids,
valid_sampled_tokens_count)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment