Commit 733236e9 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

Update model_utils.py

parent 11dd2295
import bitsandbytes as bnb
import torch
from transformers.trainer_pt_utils import LabelSmoother
......@@ -6,7 +5,7 @@ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
def find_all_linear_names(use_8bit, model):
cls = bnb.nn.Linear8bitLt if use_8bit else torch.nn.Linear
cls = torch.nn.Linear
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, cls):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment