Commit 471f811b authored by Casper Hansen

Increase robustness of model loading

parent 916fdf97
@@ -2,18 +2,19 @@ import os
import gc
import torch
import functools
import accelerate
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from huggingface_hub import snapshot_download
from awq.utils.calib_data import get_calib_dataset
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
class BaseAWQForCausalLM:
@@ -215,10 +216,17 @@ class BaseAWQForCausalLM:
    @classmethod
    def from_quantized(self, model_path, model_type, model_filename, w_bit=4, q_config={},
                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True, is_quantized=True):
                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
                       safetensors=False, is_quantized=True):
        # Download model if path is not a directory
        if not os.path.isdir(model_path):
            model_path = snapshot_download(model_path)
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt", "*.bin"])
            else:
                ignore_patterns.append("*safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        # TODO: Better naming, model_filename becomes a directory
        model_filename = model_path + f'/{model_filename}'
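For reference, the new download filtering can be read as a small standalone helper. This is a sketch only: the hub id below is a hypothetical placeholder, and the only API relied on is huggingface_hub.snapshot_download's ignore_patterns argument, which skips repository files matching the given glob patterns.

from huggingface_hub import snapshot_download

def download_weights(model_path: str, safetensors: bool = False) -> str:
    # Always skip Flax (*msgpack*) and TensorFlow (*h5*) checkpoints, then skip
    # whichever PyTorch format was not requested, so only one copy of the
    # weights is downloaded.
    ignore_patterns = ["*msgpack*", "*h5*"]
    if safetensors:
        ignore_patterns.extend(["*.pt", "*.bin"])
    else:
        ignore_patterns.append("*safetensors*")
    return snapshot_download(model_path, ignore_patterns=ignore_patterns)

# Hypothetical hub id, for illustration only.
local_dir = download_weights("user/some-awq-model", safetensors=True)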
@@ -230,15 +238,33 @@ class BaseAWQForCausalLM:
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

        # Only need to replace layers if a model is AWQ quantized
        # Only need to replace layers if a model is AWQ quantized
        if is_quantized:
            # Prepare WQLinear layers, replace nn.Linear
            self._load_quantized_modules(self, model, w_bit, q_config)

        model.tie_weights()

        # Load model weights
        model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
        try:
            model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
        except Exception as ex:
            # Fallback to auto model if load_checkpoint_and_dispatch is not working
            print(f'{ex} - falling back to AutoModelForCausalLM.from_pretrained')

            device_map = infer_auto_device_map(
                model,
                no_split_module_classes=[self.layer_type],
                dtype=torch_dtype
            )
            del model

            # Load model weights
            model = AutoModelForCausalLM.from_pretrained(
                model_filename, device_map=device_map, offload_folder="offload", offload_state_dict=True, torch_dtype=torch_dtype
            )

        model.eval()

        return self(model, model_type, is_quantized=is_quantized)
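A usage sketch of the updated classmethod follows. The subclass, hub id, checkpoint filename, and quantization config are assumptions for illustration, not values taken from the repository. With safetensors=True the download skips *.pt/*.bin files, and if load_checkpoint_and_dispatch raises, loading falls back to AutoModelForCausalLM.from_pretrained with a device map from infer_auto_device_map.

import torch
from awq.models.llama import LlamaAWQForCausalLM  # hypothetical AWQ model class

model = LlamaAWQForCausalLM.from_quantized(
    model_path="user/some-awq-model",                # hypothetical hub id or local directory
    model_type="llama",
    model_filename="awq_model_w4_g128.safetensors",  # hypothetical checkpoint name
    w_bit=4,
    q_config={"zero_point": True, "q_group_size": 128},  # assumed config keys
    device="balanced",
    torch_dtype=torch.float16,
    safetensors=True,  # new flag: fetch *.safetensors and ignore *.pt / *.bin
)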