Commit db8ba322 authored by Casper Hansen

Add safetensors support

parent 97d38e29
@@ -35,12 +35,13 @@ class AutoAWQForCausalLM:
         )
 
     @classmethod
-    def from_quantized(self, quant_path, quant_filename, max_new_tokens=None,
+    def from_quantized(self, quant_path, quant_filename='pytorch_model.bin', max_new_tokens=None,
                        device='balanced', trust_remote_code=True, fuse_layers=True,
-                       batch_size=1) -> BaseAWQForCausalLM:
+                       batch_size=1, use_safetensors=False) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code, fuse_layers=fuse_layers
+            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code,
+            fuse_layers=fuse_layers, safetensors=use_safetensors
         )
\ No newline at end of file
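The new use_safetensors flag simply flows through to the per-model loader. A minimal usage sketch against this API; the model directory is a placeholder, and quant_filename assumes the fixed name save_quantized now writes:

    from awq import AutoAWQForCausalLM

    # 'vicuna-7b-awq' is a hypothetical local directory produced by save_quantized.
    model = AutoAWQForCausalLM.from_quantized(
        "vicuna-7b-awq",
        quant_filename="model.safetensors",  # fixed name written by save_quantized
        use_safetensors=True,
    )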
@@ -11,6 +11,7 @@ from safetensors.torch import save_file
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
+from awq.utils.utils import simple_dispatch_model
 from awq.utils.calib_data import get_calib_dataset
 from transformers.modeling_utils import shard_checkpoint
 from awq.quantize.quantizer import pseudo_quantize_tensor
@@ -18,7 +19,7 @@ from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from awq.quantize.auto_clip import auto_clip_block, apply_clip
 from awq.quantize.auto_scale import auto_scale_block, apply_scale
 from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
+from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
 from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
 
 class BaseAWQForCausalLM(nn.Module):
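Swapping load_checkpoint_and_dispatch for load_checkpoint_in_model splits loading into two explicit steps: materialize the weights first, then place modules on devices (the repo's simple_dispatch_model helper, imported above). A rough CPU-only sketch of the first step on a toy model:

    import torch
    import torch.nn as nn
    from accelerate import init_empty_weights, load_checkpoint_in_model

    # Save a toy checkpoint to load back.
    src = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
    torch.save(src.state_dict(), "toy_checkpoint.bin")

    # Recreate the architecture on the meta device (no memory allocated).
    with init_empty_weights():
        model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))

    # load_checkpoint_in_model only fills in weights; unlike
    # load_checkpoint_and_dispatch it attaches no device hooks,
    # so device placement becomes a separate, explicit step.
    load_checkpoint_in_model(model, "toy_checkpoint.bin", device_map={"": "cpu"})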
@@ -217,7 +218,7 @@ class BaseAWQForCausalLM(nn.Module):
         return awq_results
 
     def save_quantized(self, save_dir, use_safetensors=False, shard_size="10GB"):
-        def _save_files(save_dir, model_name, search_result=None):
+        def _save_files(save_dir, model_name='', search_result=None):
             class EmptyModule(nn.Module):
                 def __init__(self): super(EmptyModule, self).__init__()
                 def forward(self, x): return x
@@ -232,7 +233,7 @@ class BaseAWQForCausalLM(nn.Module):
            if search_result is not None:
                torch.save(search_result, f'{save_dir}/{model_name}')
            else:
                # model_name has no extension, add it when saving state_dict
-               model_name += '.safetensors' if use_safetensors else '.bin'
+               model_name = 'model.safetensors' if use_safetensors else 'pytorch_model.bin'
 
                # shard checkpoint into chunks (10GB default)
                shards, index = shard_checkpoint(
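On the save side, the fixed file names now line up with the from_quantized defaults. A small sketch of the shard-and-save pattern this else-branch uses, on a toy model (shard_checkpoint comes from transformers):

    import torch
    import torch.nn as nn
    from safetensors.torch import save_file
    from transformers.modeling_utils import shard_checkpoint

    model = nn.Linear(8, 8)
    use_safetensors = True

    # Fixed names match what from_quantized expects after this commit.
    model_name = "model.safetensors" if use_safetensors else "pytorch_model.bin"

    # Split the state dict into <=10GB shards (a single shard here).
    shards, index = shard_checkpoint(
        model.state_dict(), max_shard_size="10GB", weights_name=model_name
    )

    for shard_file, shard in shards.items():
        if use_safetensors:
            # safetensors rejects non-contiguous and shared-storage tensors.
            save_file({k: v.contiguous() for k, v in shard.items()}, shard_file)
        else:
            torch.save(shard, shard_file)

    # index is None when everything fits in one shard.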
@@ -262,8 +263,7 @@ class BaseAWQForCausalLM(nn.Module):
 
         # Save model
         if self.search_result is None or self.is_quantized:
-            model_name = f'awq_model_w{self.quant_config["w_bit"]}_g{self.quant_config["q_group_size"]}'
-            _save_files(save_dir, model_name, search_result=None)
+            _save_files(save_dir, '', search_result=None)
         else:
             model_name = 'awq_model_search_result.pt'
             _save_files(save_dir, model_name, self.search_result)
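With the fixed file names, the quantized branch no longer derives a per-model weights name from the quant config. A hypothetical round trip under this commit:

    # Assuming `model` is an already-quantized AWQ model instance:
    model.save_quantized("vicuna-7b-awq", use_safetensors=True)
    # -> writes vicuna-7b-awq/model.safetensors (sharded if above 10GB),
    #    which from_quantized(..., use_safetensors=True) can load back.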
@@ -284,9 +284,10 @@ class BaseAWQForCausalLM(nn.Module):
         )
 
     @classmethod
-    def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
-                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
-                       safetensors=False, is_quantized=True, fuse_layers=False, version='GEMM'):
+    def from_quantized(self, model_path, model_type, model_filename='pytorch_model.bin',
+                       max_new_tokens=None, device='balanced', torch_dtype=torch.float16,
+                       trust_remote_code=True, safetensors=False, is_quantized=True,
+                       fuse_layers=False, version='GEMM'):
         # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*"]
@@ -297,8 +298,7 @@ class BaseAWQForCausalLM(nn.Module):
             model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
 
-        # TODO: Better naming, model_filename becomes a directory
-        model_filename = model_path + f'/{model_filename}'
+        model_weights_path = model_path + f'/{model_filename}'
 
         # [STEP 2] Load config and set sequence length
         # TODO: Create BaseAWQConfig class
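The rename to model_weights_path resolves the old TODO: model_path stays a directory, while model_weights_path points at one file inside it. A sketch of that resolution (the repo id is a placeholder):

    import os
    from huggingface_hub import snapshot_download

    model_path = "user/model-awq"  # hypothetical Hub repo id
    model_filename = "model.safetensors"

    # Resolve a Hub id to a local snapshot, skipping formats AWQ never loads.
    if not os.path.isdir(model_path):
        model_path = snapshot_download(model_path, ignore_patterns=["*msgpack*", "*h5*"])

    # Directory for safetensors discovery, file path for a single .bin.
    model_weights_path = model_path + f"/{model_filename}"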
@@ -341,13 +341,14 @@ class BaseAWQForCausalLM(nn.Module):
             # Load model weights
             if is_quantized:
-                model = load_checkpoint_and_dispatch(
+                load_checkpoint_in_model(
                     model,
-                    model_filename,
-                    device_map=device_map,
-                    no_split_module_classes=[self.layer_type]
+                    checkpoint=model_path if safetensors else model_weights_path,
+                    device_map=device_map
                 )
+                model = simple_dispatch_model(model, device_map)
 
                 if fuse_layers:
                     self.fuse_layers(model, quant_config)
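The checkpoint argument is why both paths are kept: accelerate's load_checkpoint_in_model accepts either a directory, in which it discovers *.safetensors (and index) files, or a concrete weights file. The branch in isolation, with placeholder values:

    # Mirrors the diff's checkpoint selection; values are placeholders.
    safetensors = True
    model_path = "vicuna-7b-awq"                            # directory
    model_weights_path = model_path + "/pytorch_model.bin"  # single file

    # Directory -> accelerate searches it for safetensors/index files;
    # file -> loaded directly.
    checkpoint = model_path if safetensors else model_weights_path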
@@ -357,7 +358,7 @@ class BaseAWQForCausalLM(nn.Module):
             # Load model weights
             model = AutoModelForCausalLM.from_pretrained(
-                model_filename,
+                model_weights_path,
                 device_map=device_map,
                 trust_remote_code=trust_remote_code,
                 offload_folder="offload",