Commit 741dbbbb authored by zhuwenwen
Browse files

update mlp

parent 9d5e4dd9
......@@ -19,6 +19,7 @@ from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm import _custom_ops as ops
from vllm.distributed import tensor_model_parallel_all_gather, tensor_model_parallel_gather
+ from vllm import envs
SQRT2 = 2**0.5
......@@ -215,7 +216,7 @@ class MLPSpeculator(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(name)
- if self.use_llama_nn:
+ if self.use_llama_nn or envs.VLLM_USE_NN:
if (os.environ['LM_NN'] == '1' and "head" in name) or "proj" in name:
_weight = torch.zeros_like(param.data)
ori_shape =_weight.shape
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment