Commit d623e721 authored by 王敏's avatar 王敏
Browse files

[fix]修复单测test_mlp_correctness失败问题

parent 217ee621
......@@ -244,7 +244,7 @@ class Medusa(nn.Module):
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn and "lm_head" in name:
if self.use_llama_nn and os.environ['LM_NN'] == '1' and "lm_head" in name:
_weight = torch.zeros_like(param.data)
ori_shape =_weight.shape
......
......@@ -201,7 +201,7 @@ class MLPSpeculator(nn.Module):
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn and "head" in name:
if self.use_llama_nn and os.environ['LM_NN'] == '1' and "head" in name:
_weight = torch.zeros_like(param.data)
ori_shape =_weight.shape
......
......@@ -737,7 +737,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Get probabilities according to proposal method.
proposal_probs = proposals.proposal_probs if proposals.proposal_probs is not None else None
if non_spec_indices:
if proposal_probs is not None and non_spec_indices:
proposal_probs = proposal_probs[spec_indices]
# Get proposed tokens.
......@@ -747,7 +747,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Get tree buffers.
cart_candidates = proposals.cart_candidates if proposals.cart_candidates is not None else None
if non_spec_indices:
if cart_candidates is not None and non_spec_indices:
cart_candidates = cart_candidates[spec_indices]
# Sampler arguments
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment