Commit db8ba322 authored by Casper Hansen

Add safetensors support

parent 97d38e29
@@ -35,12 +35,13 @@ class AutoAWQForCausalLM:
         )
 
     @classmethod
-    def from_quantized(self, quant_path, quant_filename, max_new_tokens=None,
+    def from_quantized(self, quant_path, quant_filename='pytorch_model.bin', max_new_tokens=None,
                        device='balanced', trust_remote_code=True, fuse_layers=True,
-                       batch_size=1) -> BaseAWQForCausalLM:
+                       batch_size=1, use_safetensors=False) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code, fuse_layers=fuse_layers
+            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code,
+            fuse_layers=fuse_layers, safetensors=use_safetensors
         )
\ No newline at end of file
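
For context, a minimal usage sketch of the new flag at the auto-class level. The repo id below is a placeholder, and the behavior noted in the comment is read off this diff (with `use_safetensors=True`, the base class passes the model directory rather than a single file to the checkpoint loader), not a documented guarantee:

```python
from awq import AutoAWQForCausalLM  # assumes the package exposes the auto class at top level

# Load AWQ weights stored as safetensors; the path/repo id is a placeholder.
model = AutoAWQForCausalLM.from_quantized(
    "your-org/llama-7b-awq",             # local dir or Hugging Face repo id
    quant_filename="model.safetensors",  # per the diff, the folder is used as the checkpoint when safetensors is set
    use_safetensors=True,
)
```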
@@ -11,6 +11,7 @@ from safetensors.torch import save_file
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
+from awq.utils.utils import simple_dispatch_model
 from awq.utils.calib_data import get_calib_dataset
 from transformers.modeling_utils import shard_checkpoint
 from awq.quantize.quantizer import pseudo_quantize_tensor
@@ -18,7 +19,7 @@ from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from awq.quantize.auto_clip import auto_clip_block, apply_clip
 from awq.quantize.auto_scale import auto_scale_block, apply_scale
 from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
+from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
 from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
 
 class BaseAWQForCausalLM(nn.Module):
@@ -217,7 +218,7 @@ class BaseAWQForCausalLM(nn.Module):
         return awq_results
 
     def save_quantized(self, save_dir, use_safetensors=False, shard_size="10GB"):
-        def _save_files(save_dir, model_name, search_result=None):
+        def _save_files(save_dir, model_name='', search_result=None):
             class EmptyModule(nn.Module):
                 def __init__(self): super(EmptyModule, self).__init__()
                 def forward(self, x): return x
@@ -232,7 +233,7 @@ class BaseAWQForCausalLM(nn.Module):
                 torch.save(search_result, f'{save_dir}/{model_name}')
             else:
                 # model_name has no extension, add it when saving state_dict
-                model_name += '.safetensors' if use_safetensors else '.bin'
+                model_name = 'model.safetensors' if use_safetensors else 'pytorch_model.bin'
 
                 # shard checkpoint into chunks (10GB default)
                 shards, index = shard_checkpoint(
@@ -262,8 +263,7 @@ class BaseAWQForCausalLM(nn.Module):
 
         # Save model
         if self.search_result is None or self.is_quantized:
-            model_name = f'awq_model_w{self.quant_config["w_bit"]}_g{self.quant_config["q_group_size"]}'
-            _save_files(save_dir, model_name, search_result=None)
+            _save_files(save_dir, '', search_result=None)
         else:
             model_name = 'awq_model_search_result.pt'
             _save_files(save_dir, model_name, self.search_result)
@@ -284,9 +284,10 @@ class BaseAWQForCausalLM(nn.Module):
         )
 
     @classmethod
-    def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
-                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
-                       safetensors=False, is_quantized=True, fuse_layers=False, version='GEMM'):
+    def from_quantized(self, model_path, model_type, model_filename='pytorch_model.bin',
+                       max_new_tokens=None, device='balanced', torch_dtype=torch.float16,
+                       trust_remote_code=True, safetensors=False, is_quantized=True,
+                       fuse_layers=False, version='GEMM'):
         # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*"]
@@ -297,8 +298,7 @@ class BaseAWQForCausalLM(nn.Module):
             model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
 
-        # TODO: Better naming, model_filename becomes a directory
-        model_filename = model_path + f'/{model_filename}'
+        model_weights_path = model_path + f'/{model_filename}'
 
         # [STEP 2] Load config and set sequence length
         # TODO: Create BaseAWQConfig class
@@ -341,13 +341,14 @@ class BaseAWQForCausalLM(nn.Module):
         # Load model weights
         if is_quantized:
-            model = load_checkpoint_and_dispatch(
+            load_checkpoint_in_model(
                 model,
-                model_filename,
-                device_map=device_map,
-                no_split_module_classes=[self.layer_type]
+                checkpoint=model_path if safetensors else model_weights_path,
+                device_map=device_map
             )
+            model = simple_dispatch_model(model, device_map)
 
             if fuse_layers:
                 self.fuse_layers(model, quant_config)
@@ -357,7 +358,7 @@ class BaseAWQForCausalLM(nn.Module):
             # Load model weights
             model = AutoModelForCausalLM.from_pretrained(
-                model_filename,
+                model_weights_path,
                 device_map=device_map,
                 trust_remote_code=trust_remote_code,
                 offload_folder="offload",
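On the saving side, the same flag now pins the on-disk file names (`model.safetensors` vs `pytorch_model.bin`). A rough sketch of the intended flow; the quantization config, tokenizer, and output directory are illustrative placeholders and the `quantize()` call assumes the existing AutoAWQ quantization API rather than anything introduced in this commit:

```python
# Sketch only: quantize an already-loaded model, then export as safetensors.
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

model.quantize(tokenizer, quant_config=quant_config)         # assumed pre-existing quantize() API
model.save_quantized("llama-7b-awq", use_safetensors=True)   # writes model.safetensors shards per the diff

# The result can then be reloaded through AutoAWQForCausalLM.from_quantized(..., use_safetensors=True).
```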