Commit dc2ce261 authored by 王敏's avatar 王敏
Browse files

[fix]修复test_spec_decode_worker中的错误

parent 137e8a16
...@@ -550,14 +550,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -550,14 +550,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
hidden_states, execute_model_req.seq_group_metadata_list) hidden_states, execute_model_req.seq_group_metadata_list)
# Store logits from target model execution. # Store logits from target model execution.
logits = sampler_output.logits if self.tree_style_spec_decoding:
if logits is not None: logits = sampler_output.logits
if self.previous_logits is None: if logits is not None:
self.previous_logits = Logits( if self.previous_logits is None:
logits, execute_model_req.seq_group_metadata_list) self.previous_logits = Logits(
else: logits, execute_model_req.seq_group_metadata_list)
self.previous_logits.update( else:
logits, execute_model_req.seq_group_metadata_list) self.previous_logits.update(
logits, execute_model_req.seq_group_metadata_list)
if not skip_proposer: if not skip_proposer:
# We prepare the prefill hidden states here so that there no # We prepare the prefill hidden states here so that there no
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment