"gallery/transforms/plot_transforms_illustrations.py" did not exist on "408917d19b151297831dc1dab481c029db99f5a5"
Commit 97d38e29 authored by Casper Hansen
Browse files

Implement saving sharded weights + safetensors

parent d76125bf
......@@ -7,10 +7,12 @@ import functools
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from safetensors.torch import save_file
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.utils.calib_data import get_calib_dataset
from transformers.modeling_utils import shard_checkpoint
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.quantize.auto_clip import auto_clip_block, apply_clip
......@@ -214,20 +216,43 @@ class BaseAWQForCausalLM(nn.Module):
return awq_results
def save_quantized(self, save_dir):
def _save_files(save_dir, model_name, model):
def save_quantized(self, save_dir, use_safetensors=False, shard_size="10GB"):
def _save_files(save_dir, model_name, search_result=None):
class EmptyModule(nn.Module):
def __init__(self): super(EmptyModule, self).__init__()
def forward(self, x): return x
# Save model fiels without search results
# Save model files with empty state dict
self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
# Remove empty module
# Remove empty state dict
os.remove(f'{save_dir}/pytorch_model.bin')
# Save search results
torch.save(model, f'{save_dir}/{model_name}')
if search_result is not None:
torch.save(search_result, f'{save_dir}/{model_name}')
else:
# model_name has no extension, add it when saving state_dict
model_name += '.safetensors' if use_safetensors else '.bin'
# shard checkpoint into chunks (10GB default)
shards, index = shard_checkpoint(
self.model.state_dict(),
max_shard_size=shard_size,
weights_name=model_name
)
for shard_file, shard in shards.items():
if use_safetensors:
# safetensors must be in the same memory, so we duplicate and use contiguous memory
shard = {k: v.clone().contiguous() for k, v in shard.items()}
save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
else:
torch.save(shard, os.path.join(save_dir, shard_file))
# save shard index
if index is not None:
with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
file.write(json.dumps(index, indent=4))
# Save config
with open(f'{save_dir}/quant_config.json', 'w+') as file:
......@@ -237,8 +262,8 @@ class BaseAWQForCausalLM(nn.Module):
# Save model
if self.search_result is None or self.is_quantized:
model_name = f'awq_model_w{self.quant_config["w_bit"]}_g{self.quant_config["q_group_size"]}.pt'
_save_files(save_dir, model_name, self.model.state_dict())
model_name = f'awq_model_w{self.quant_config["w_bit"]}_g{self.quant_config["q_group_size"]}'
_save_files(save_dir, model_name, search_result=None)
else:
model_name = 'awq_model_search_result.pt'
_save_files(save_dir, model_name, self.search_result)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment