Unverified commit a49ea5a5, authored by Nick Hill, committed by GitHub
Browse files

[Model Runner V2] A bit more PP simplification (#34766)


Signed-off-by: Nick Hill <nickhill123@gmail.com>
parent 30ebe0dc
......@@ -1003,15 +1003,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states, input_batch, kv_connector_output = self.execute_model_state
self.execute_model_state = None # type: ignore
if not self.is_last_pp_rank:
# Non-last PP rank: hidden_states is None because this rank produced
# IntermediateTensors instead of final hidden states. Receive the
# sampled tokens broadcast by the last rank and update local state.
if not self.is_last_pp_rank:
received = pp_receive(
# sampled tokens broadcast from the last rank and update local state.
sampled, num_sampled, num_rejected = pp_receive(
input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1
)
assert received is not None
sampled, num_sampled, num_rejected = received
self.postprocess(input_batch, sampled, num_sampled, num_rejected)
return None
......@@ -1020,8 +1018,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states, input_batch, grammar_output
)
# Broadcast to non-last PP ranks (handles spec decode multi-token).
if self.use_pp:
# Broadcast to non-last PP ranks (handles spec decode multi-token).
pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected)
prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
......
......@@ -13,8 +13,7 @@ def pp_broadcast(
num_rejected: torch.Tensor,
) -> None:
pp = get_pp_group()
if not pp.is_last_rank:
return
assert pp.is_last_rank
assert sampled_token_ids.dtype == torch.int64
torch.distributed.broadcast(
......@@ -27,10 +26,9 @@ def pp_broadcast(
def pp_receive(
num_reqs: int, max_sample_len: int = 1
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pp = get_pp_group()
if pp.is_last_rank:
return None
assert not pp.is_last_rank
sampled_tokens = torch.empty(
num_reqs, max_sample_len, dtype=torch.int64, device=pp.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment