"vscode:/vscode.git/clone" did not exist on "5b681074119b970c2f99f8baea43f856cafc0251"
Unverified Commit a09c7ca9 authored by Aaron Pham's avatar Aaron Pham Committed by GitHub
Browse files

[Chore][Spec Decode] Update check NoneType instead of assigning variables (#18836)


Signed-off-by: default avatarAaron Pham <contact@aarnphm.xyz>
parent 0e98964e
...@@ -146,16 +146,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -146,16 +146,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# req_id -> (input_id -> encoder_output) # req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
# Set up speculative decoding.
self.use_spec_decode = False
self.use_aux_hidden_state_outputs = False self.use_aux_hidden_state_outputs = False
if self.speculative_config: # Set up speculative decoding.
self.use_spec_decode = True
# NOTE(Jiayi): currently we put the entire draft model on # NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many # the last PP rank. This is not ideal if there are many
# layers in the draft model. # layers in the draft model.
if get_pp_group().is_last_rank: if self.speculative_config and get_pp_group().is_last_rank:
if self.speculative_config.method == "ngram": if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config) self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.use_eagle(): elif self.speculative_config.use_eagle():
...@@ -1318,7 +1314,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1318,7 +1314,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for i in discard_sampled_tokens_req_indices: for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear() valid_sampled_token_ids[i].clear()
if not self.use_spec_decode: if not self.speculative_config:
# Speculative decoding is not enabled. # Speculative decoding is not enabled.
spec_token_ids = None spec_token_ids = None
elif self.speculative_config.method == "ngram": elif self.speculative_config.method == "ngram":
...@@ -1740,7 +1736,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1740,7 +1736,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
hidden_states = outputs hidden_states = outputs
if self.use_spec_decode and self.speculative_config.use_eagle(): if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens) self.drafter.dummy_run(num_tokens)
...@@ -1795,7 +1791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1795,7 +1791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"initializing the engine.") from e "initializing the engine.") from e
else: else:
raise e raise e
if self.use_spec_decode: if self.speculative_config:
draft_token_ids = [[0] for _ in range(num_reqs)] draft_token_ids = [[0] for _ in range(num_reqs)]
dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy( dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids, self.device) draft_token_ids, self.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment