Commit f6ce3afa authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.6.2-dev' of ssh://10.6.10.68:10022/dcutoolkit/deeplearing/vllm into v0.6.2-dev

parents 78800ecf 1a313afb
...@@ -244,7 +244,7 @@ class Medusa(nn.Module): ...@@ -244,7 +244,7 @@ class Medusa(nn.Module):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn and "lm_head" in name: if self.use_llama_nn and os.environ['LM_NN'] == '1' and "lm_head" in name:
_weight = torch.zeros_like(param.data) _weight = torch.zeros_like(param.data)
ori_shape =_weight.shape ori_shape =_weight.shape
......
...@@ -201,7 +201,7 @@ class MLPSpeculator(nn.Module): ...@@ -201,7 +201,7 @@ class MLPSpeculator(nn.Module):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn and "head" in name: if self.use_llama_nn and os.environ['LM_NN'] == '1' and "head" in name:
_weight = torch.zeros_like(param.data) _weight = torch.zeros_like(param.data)
ori_shape =_weight.shape ori_shape =_weight.shape
......
...@@ -531,6 +531,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -531,6 +531,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
not called, meaning that the kv-cache in proposer for requests is not not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding. updated, so they cannot enable spec decode in the rest decoding.
""" """
if self.tree_style_spec_decoding and self.kvcache_slot_to_be_moved is not None:
execute_model_req.kvcache_slot_to_be_moved = self.kvcache_slot_to_be_moved
self.kvcache_slot_to_be_moved = None
sampler_output = self.scorer_worker.execute_model(execute_model_req) sampler_output = self.scorer_worker.execute_model(execute_model_req)
assert len(sampler_output) == 1 assert len(sampler_output) == 1
...@@ -734,7 +737,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -734,7 +737,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Get probabilities according to proposal method. # Get probabilities according to proposal method.
proposal_probs = proposals.proposal_probs if proposals.proposal_probs is not None else None proposal_probs = proposals.proposal_probs if proposals.proposal_probs is not None else None
if non_spec_indices: if proposal_probs is not None and non_spec_indices:
proposal_probs = proposal_probs[spec_indices] proposal_probs = proposal_probs[spec_indices]
# Get proposed tokens. # Get proposed tokens.
...@@ -744,7 +747,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -744,7 +747,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Get tree buffers. # Get tree buffers.
cart_candidates = proposals.cart_candidates if proposals.cart_candidates is not None else None cart_candidates = proposals.cart_candidates if proposals.cart_candidates is not None else None
if non_spec_indices: if cart_candidates is not None and non_spec_indices:
cart_candidates = cart_candidates[spec_indices] cart_candidates = cart_candidates[spec_indices]
# Sampler arguments # Sampler arguments
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment