Commit 0cf05716 authored by jujl1
Browse files

fix: 修复丢弃MTP代码报错

parent c1795786
...@@ -1643,9 +1643,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1643,9 +1643,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Mask out the sampled tokens that should not be sampled. # Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices: for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear() valid_sampled_token_ids[i].clear()
if spec_token_ids is not None:
for i in discard_sampled_tokens_req_indices:
spec_token_ids[i].clear()
# Cache the sampled tokens in the model runner, so that the scheduler # Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back. # doesn't need to send them back.
...@@ -1685,7 +1682,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1685,7 +1682,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
spec_decode_metadata, spec_decode_metadata,
attn_metadata, attn_metadata,
) )
if spec_token_ids is not None:
for i in discard_sampled_tokens_req_indices:
spec_token_ids[i].clear()
# Clear KVConnector state after all KVs are generated. # Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata() get_kv_transfer_group().clear_connector_metadata()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment