Unverified Commit b6d5a172 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Model Runner V2] Fix error-handling (#35063)


Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
parent 5e58bdc7
...@@ -227,6 +227,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -227,6 +227,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# KV Connector if configured. # KV Connector if configured.
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
# For transferring state from execute_model to subsequent sample_tokens call.
self.execute_model_state: tuple | None = None
def update_max_model_len(self, max_model_len: int) -> None: def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.req_states.max_model_len = max_model_len self.req_states.max_model_len = max_model_len
...@@ -388,6 +391,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -388,6 +391,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert self.execute_model_state is not None assert self.execute_model_state is not None
hidden_states, _, input_batch, _ = self.execute_model_state hidden_states, _, input_batch, _ = self.execute_model_state
self.execute_model_state = None
assert hidden_states is not None # Last PP rank always has hidden_states assert hidden_states is not None # Last PP rank always has hidden_states
sample_hidden_states = hidden_states[input_batch.logits_indices] sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states return hidden_states, sample_hidden_states
...@@ -1036,18 +1040,20 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1036,18 +1040,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states, aux_hidden_states,
input_batch, input_batch,
kv_connector_output, kv_connector_output,
) # type: ignore )
return None return None
@torch.inference_mode() @torch.inference_mode()
def sample_tokens( def sample_tokens(
self, grammar_output: GrammarOutput | None self, grammar_output: GrammarOutput | None
) -> AsyncOutput | ModelRunnerOutput | None: ) -> AsyncOutput | ModelRunnerOutput | None:
assert self.execute_model_state is not None if self.execute_model_state is None:
# The prior execute_model call must have failed.
return None
hidden_states, aux_hidden_states, input_batch, kv_connector_output = ( hidden_states, aux_hidden_states, input_batch, kv_connector_output = (
self.execute_model_state self.execute_model_state
) )
self.execute_model_state = None # type: ignore self.execute_model_state = None
if not self.is_last_pp_rank: if not self.is_last_pp_rank:
# Non-last PP rank: hidden_states is None because this rank produced # Non-last PP rank: hidden_states is None because this rank produced
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment