Unverified commit a49ea5a5, authored by Nick Hill and committed by GitHub
Browse files

[Model Runner V2] A bit more PP simplification (#34766)


Signed-off-by: Nick Hill <nickhill123@gmail.com>
parent 30ebe0dc
...@@ -1003,15 +1003,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1003,15 +1003,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states, input_batch, kv_connector_output = self.execute_model_state hidden_states, input_batch, kv_connector_output = self.execute_model_state
self.execute_model_state = None # type: ignore self.execute_model_state = None # type: ignore
if not self.is_last_pp_rank:
# Non-last PP rank: hidden_states is None because this rank produced # Non-last PP rank: hidden_states is None because this rank produced
# IntermediateTensors instead of final hidden states. Receive the # IntermediateTensors instead of final hidden states. Receive the
# sampled tokens broadcast by the last rank and update local state. # sampled tokens broadcast from the last rank and update local state.
if not self.is_last_pp_rank: sampled, num_sampled, num_rejected = pp_receive(
received = pp_receive(
input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1 input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1
) )
assert received is not None
sampled, num_sampled, num_rejected = received
self.postprocess(input_batch, sampled, num_sampled, num_rejected) self.postprocess(input_batch, sampled, num_sampled, num_rejected)
return None return None
...@@ -1020,8 +1018,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1020,8 +1018,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states, input_batch, grammar_output hidden_states, input_batch, grammar_output
) )
# Broadcast to non-last PP ranks (handles spec decode multi-token).
if self.use_pp: if self.use_pp:
# Broadcast to non-last PP ranks (handles spec decode multi-token).
pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected) pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected)
prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs( prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
......
...@@ -13,8 +13,7 @@ def pp_broadcast( ...@@ -13,8 +13,7 @@ def pp_broadcast(
num_rejected: torch.Tensor, num_rejected: torch.Tensor,
) -> None: ) -> None:
pp = get_pp_group() pp = get_pp_group()
if not pp.is_last_rank: assert pp.is_last_rank
return
assert sampled_token_ids.dtype == torch.int64 assert sampled_token_ids.dtype == torch.int64
torch.distributed.broadcast( torch.distributed.broadcast(
...@@ -27,10 +26,9 @@ def pp_broadcast( ...@@ -27,10 +26,9 @@ def pp_broadcast(
def pp_receive( def pp_receive(
num_reqs: int, max_sample_len: int = 1 num_reqs: int, max_sample_len: int = 1
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pp = get_pp_group() pp = get_pp_group()
if pp.is_last_rank: assert not pp.is_last_rank
return None
sampled_tokens = torch.empty( sampled_tokens = torch.empty(
num_reqs, max_sample_len, dtype=torch.int64, device=pp.device num_reqs, max_sample_len, dtype=torch.int64, device=pp.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment