Commit 471f811b authored by Casper Hansen

Increase robustness of model loading

parent 916fdf97
@@ -2,18 +2,19 @@ import os
 import gc
 import torch
 import functools
+import accelerate
 import torch.nn as nn
 from tqdm import tqdm
 from collections import defaultdict
 from huggingface_hub import snapshot_download
 from awq.utils.calib_data import get_calib_dataset
-from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
 from awq.quantize.quantizer import pseudo_quantize_tensor
 from awq.quantize.qmodule import WQLinear, ScaledActivation
 from awq.quantize.auto_clip import auto_clip_block, apply_clip
 from awq.quantize.auto_scale import auto_scale_block, apply_scale
+from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
 from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name


 class BaseAWQForCausalLM:
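The newly imported `infer_auto_device_map` powers the fallback path further down in this commit. A minimal sketch of how it pairs with `init_empty_weights`; the model name and no-split class below are illustrative assumptions, not part of this commit:

```python
# Sketch only: shows the accelerate APIs imported above; the model name and
# no-split class are illustrative assumptions, not taken from this commit.
import torch
from accelerate import init_empty_weights, infer_auto_device_map
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("facebook/opt-125m")
with init_empty_weights():
    # Build the module tree on the "meta" device: no weight memory is allocated
    model = AutoModelForCausalLM.from_config(config)

# Assign each submodule to a device without splitting a decoder block apart
device_map = infer_auto_device_map(
    model,
    no_split_module_classes=["OPTDecoderLayer"],
    dtype=torch.float16,
)
print(device_map)  # e.g. {'model.decoder.embed_tokens': 0, ...}
```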
@@ -215,10 +216,17 @@ class BaseAWQForCausalLM:
     @classmethod
     def from_quantized(self, model_path, model_type, model_filename, w_bit=4, q_config={},
-                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True, is_quantized=True):
+                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
+                       safetensors=False, is_quantized=True):
         # Download model if path is not a directory
         if not os.path.isdir(model_path):
-            model_path = snapshot_download(model_path)
+            ignore_patterns = ["*msgpack*", "*h5*"]
+            if safetensors:
+                ignore_patterns.extend(["*.pt", "*.bin"])
+            else:
+                ignore_patterns.append("*safetensors*")
+
+            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

         # TODO: Better naming, model_filename becomes a directory
         model_filename = model_path + f'/{model_filename}'
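For reference, the new download filter in isolation: `snapshot_download` skips any file matching `ignore_patterns`, so only one weight format is fetched. The repo id in this sketch is a placeholder:

```python
# Sketch of the filtering logic added above; the repo id is a placeholder.
from huggingface_hub import snapshot_download

safetensors = False
ignore_patterns = ["*msgpack*", "*h5*"]          # Flax/TF weights are never needed
if safetensors:
    ignore_patterns.extend(["*.pt", "*.bin"])    # fetch only .safetensors shards
else:
    ignore_patterns.append("*safetensors*")      # fetch only torch .bin/.pt weights

# Only files not matching ignore_patterns are downloaded into the local cache
model_path = snapshot_download("org/awq-model", ignore_patterns=ignore_patterns)
print(model_path)  # local snapshot directory
```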
@@ -230,15 +238,33 @@ class BaseAWQForCausalLM:
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

         # Only need to replace layers if a model is AWQ quantized
         if is_quantized:
             # Prepare WQLinear layers, replace nn.Linear
             self._load_quantized_modules(self, model, w_bit, q_config)

         model.tie_weights()

         # Load model weights
-        model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
+        try:
+            model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
+        except Exception as ex:
+            # Fall back to AutoModelForCausalLM if load_checkpoint_and_dispatch fails
+            print(f'{ex} - falling back to AutoModelForCausalLM.from_pretrained')
+
+            device_map = infer_auto_device_map(
+                model,
+                no_split_module_classes=[self.layer_type],
+                dtype=torch_dtype
+            )
+            del model
+
+            # Load model weights
+            model = AutoModelForCausalLM.from_pretrained(
+                model_filename, device_map=device_map, offload_folder="offload", offload_state_dict=True, torch_dtype=torch_dtype
+            )
+            model.eval()

         return self(model, model_type, is_quantized=is_quantized)
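Taken together, a caller now gets two chances to load weights: `load_checkpoint_and_dispatch` first, then a plain `from_pretrained` with an inferred device map and disk offload. A hypothetical call site, assuming an `AutoAWQForCausalLM` entry point wraps this base class; the repo id and file name are placeholders, only the keyword arguments come from this diff:

```python
# Hypothetical usage; the entry-point class, repo id and file name are
# assumptions for illustration, only the keyword arguments come from this diff.
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "org/awq-model",                 # hub repo id or local directory
    "awq_model_w4_g128.pt",          # model_filename inside that snapshot
    w_bit=4,
    q_config={"zero_point": True, "q_group_size": 128},
    safetensors=False,               # new flag: skip *.safetensors at download time
)
```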
...