Unverified commit 8788fe10 authored by Casper, committed by GitHub

Merge pull request #47 from casper-hansen/safetensors

Safetensors and model sharding
parents 2a3e0fa1 4ecd8594
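
For orientation, a minimal end-to-end sketch of the API this merge enables. It is not part of the diff; the base model path is a placeholder, and the keyword defaults mirror the signatures changed below (safetensors=False, shard_size="10GB").

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'lmsys/vicuna-7b-v1.5'   # placeholder base model
quant_path = 'vicuna-7b-v1.5-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Quantize as before
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.quantize(tokenizer, quant_config=quant_config)

# New: write safetensors weights (sharded when the state dict exceeds shard_size)
model.save_quantized(quant_path, safetensors=True, shard_size="10GB")
tokenizer.save_pretrained(quant_path)

# New: load the quantized model back from safetensors weights
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, safetensors=True)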
@@ -35,12 +35,13 @@ class AutoAWQForCausalLM:
         )
 
     @classmethod
-    def from_quantized(self, quant_path, quant_filename, max_new_tokens=None,
+    def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
                        device='balanced', trust_remote_code=True, fuse_layers=True,
-                       batch_size=1) -> BaseAWQForCausalLM:
+                       batch_size=1, safetensors=False) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code, fuse_layers=fuse_layers
+            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code,
+            fuse_layers=fuse_layers, safetensors=safetensors
         )
\ No newline at end of file
@@ -8,16 +8,19 @@ import torch.nn as nn
 from tqdm import tqdm
 from typing import List, Union
 from collections import defaultdict
+from safetensors.torch import save_file
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
+from awq.utils.utils import simple_dispatch_model
 from awq.utils.calib_data import get_calib_dataset
+from transformers.modeling_utils import shard_checkpoint
 from awq.quantize.quantizer import pseudo_quantize_tensor
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from awq.quantize.auto_clip import auto_clip_block, apply_clip
 from awq.quantize.auto_scale import auto_scale_block, apply_scale
 from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
+from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
 from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
 
 
 class BaseAWQForCausalLM(nn.Module):
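
The new imports above are the whole machinery: safetensors.torch.save_file writes a flat dict of tensors, transformers' shard_checkpoint splits a state dict into size-bounded shards plus an index, and accelerate's load_checkpoint_in_model (replacing load_checkpoint_and_dispatch) fills an already-built model according to a device map. A toy, purely illustrative round trip with the safetensors half (the file name is made up):

import torch
from safetensors.torch import save_file, load_file

# two independent, contiguous tensors; safetensors rejects views that share storage
tensors = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}
save_file(tensors, "toy.safetensors", metadata={"format": "pt"})

loaded = load_file("toy.safetensors")   # plain dict of tensors, no pickle involved
assert torch.equal(loaded["weight"], tensors["weight"])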
@@ -222,20 +225,43 @@ class BaseAWQForCausalLM(nn.Module):
 
         return awq_results
 
-    def save_quantized(self, save_dir):
-        def _save_files(save_dir, model_name, model):
+    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
+        def _save_files(save_dir, model_name='', search_result=None):
             class EmptyModule(nn.Module):
                 def __init__(self): super(EmptyModule, self).__init__()
                 def forward(self, x): return x
 
-            # Save model fiels without search results
+            # Save model files with empty state dict
             self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
 
-            # Remove empty module
+            # Remove empty state dict
             os.remove(f'{save_dir}/pytorch_model.bin')
 
-            # Save search results
-            torch.save(model, f'{save_dir}/{model_name}')
+            if search_result is not None:
+                torch.save(search_result, f'{save_dir}/{model_name}')
+            else:
+                # model_name has no extension, add it when saving state_dict
+                model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'
+
+                # shard checkpoint into chunks (10GB default)
+                shards, index = shard_checkpoint(
+                    self.model.state_dict(),
+                    max_shard_size=shard_size,
+                    weights_name=model_name
+                )
+
+                for shard_file, shard in shards.items():
+                    if safetensors:
+                        # safetensors must be in the same memory, so we duplicate and use contiguous memory
+                        shard = {k: v.clone().contiguous() for k, v in shard.items()}
+                        save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
+                    else:
+                        torch.save(shard, os.path.join(save_dir, shard_file))
+
+                # save shard index
+                if index is not None:
+                    with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
+                        file.write(json.dumps(index, indent=4))
 
         # Save config
         with open(f'{save_dir}/quant_config.json', 'w+') as file:
@@ -245,8 +271,7 @@ class BaseAWQForCausalLM(nn.Module):
 
         # Save model
         if self.search_result is None or self.is_quantized:
-            model_name = f'awq_model_w{self.quant_config["w_bit"]}_g{self.quant_config["q_group_size"]}.pt'
-            _save_files(save_dir, model_name, self.model.state_dict())
+            _save_files(save_dir, '', search_result=None)
         else:
            model_name = 'awq_model_search_result.pt'
            _save_files(save_dir, model_name, self.search_result)
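
A sketch of what shard_checkpoint hands back in the new else-branch above: a mapping from shard file name to partial state dict, plus an index (None when everything fits in one shard) whose weight_map ties each parameter name to its shard file. Purely illustrative; model stands for an already-quantized module, and the printed names are examples only.

from transformers.modeling_utils import shard_checkpoint

shards, index = shard_checkpoint(
    model.state_dict(),
    max_shard_size="10GB",
    weights_name="model.safetensors"
)

print(list(shards))
# e.g. ['model.safetensors'] for a small model, or
# ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors'] when sharded

if index is not None:
    print(index["weight_map"])   # parameter name -> shard file name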
@@ -267,21 +292,24 @@ class BaseAWQForCausalLM(nn.Module):
         )
 
     @classmethod
-    def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
-                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
-                       safetensors=False, is_quantized=True, fuse_layers=False, version='GEMM'):
+    def from_quantized(self, model_path, model_type, model_filename='',
+                       max_new_tokens=None, device='balanced', torch_dtype=torch.float16,
+                       trust_remote_code=True, safetensors=False, is_quantized=True,
+                       fuse_layers=False, version='GEMM'):
         # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*"]
             if safetensors:
-                ignore_patterns.extend(["*.pt", "*.bin"])
+                ignore_patterns.extend(["*.pt*", "*.bin*"])
             else:
-                ignore_patterns.append("*safetensors*")
+                ignore_patterns.append("*.safetensors*")
 
             model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
 
-        # TODO: Better naming, model_filename becomes a directory
-        model_filename = model_path + f'/{model_filename}'
+        if model_filename != '':
+            model_weights_path = model_path + f'/{model_filename}'
+        else:
+            model_weights_path = model_path
 
         # [STEP 2] Load config and set sequence length
         # TODO: Create BaseAWQConfig class
@@ -324,13 +352,14 @@ class BaseAWQForCausalLM(nn.Module):
 
         # Load model weights
         if is_quantized:
-            model = load_checkpoint_and_dispatch(
+            load_checkpoint_in_model(
                 model,
-                model_filename,
-                device_map=device_map,
-                no_split_module_classes=[self.layer_type]
+                checkpoint=model_weights_path,
+                device_map=device_map
             )
+            model = simple_dispatch_model(model, device_map)
 
             if fuse_layers:
                 self.fuse_layers(model, quant_config)
@@ -340,7 +369,7 @@ class BaseAWQForCausalLM(nn.Module):
 
             # Load model weights
             model = AutoModelForCausalLM.from_pretrained(
-                model_filename,
+                model_weights_path,
                 device_map=device_map,
                 trust_remote_code=trust_remote_code,
                 offload_folder="offload",
...
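
The loading side above now splits what load_checkpoint_and_dispatch used to do in one call: load_checkpoint_in_model copies the (possibly sharded, safetensors or .bin) checkpoint into the model per the device map, and AWQ's own simple_dispatch_model then attaches the hooks that move activations between devices. A rough standalone sketch of that generic accelerate pattern, not of AWQ's exact code path: dispatch_model stands in for the project-internal simple_dispatch_model, and the checkpoint path and no-split class name are placeholders.

import torch
from transformers import AutoConfig, AutoModelForCausalLM
from accelerate import (init_empty_weights, infer_auto_device_map,
                        load_checkpoint_in_model, dispatch_model)

checkpoint_dir = "path/to/checkpoint-dir"   # placeholder

# Build the module tree without allocating real weights
config = AutoConfig.from_pretrained(checkpoint_dir)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Decide where each block lives, then fill weights straight onto those devices
device_map = infer_auto_device_map(model, no_split_module_classes=["OPTDecoderLayer"])  # placeholder class
load_checkpoint_in_model(model, checkpoint=checkpoint_dir, device_map=device_map)

# Attach hooks that move inputs/outputs between devices at runtime
model = dispatch_model(model, device_map=device_map)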
@@ -6,6 +6,7 @@ quant_path = 'vicuna-7b-v1.5-awq'
 quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
 
 # Load model
+# NOTE: pass safetensors=True to load safetensors
 model = AutoAWQForCausalLM.from_pretrained(model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@@ -13,6 +14,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 model.quantize(tokenizer, quant_config=quant_config)
 
 # Save quantized model
+# NOTE: pass safetensors=True to save quantized model weights as safetensors
 model.save_quantized(quant_path)
 tokenizer.save_pretrained(quant_path)
...
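
With the flag from the NOTE above actually passed, the save step looks roughly like this (illustrative; shard file names follow the Transformers convention used by shard_checkpoint):

# Same example script as above, opting into safetensors output
model.save_quantized(quant_path, safetensors=True)
# -> writes model.safetensors, or model-0000x-of-0000y.safetensors plus
#    model.safetensors.index.json when the state dict exceeds shard_size (10GB default)
tokenizer.save_pretrained(quant_path)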
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

quant_path = "casperhansen/opt-125m-awq"

# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, safetensors=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_special_tokens=True)

# Convert prompt to tokens
prompt_template = """\
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
USER: {prompt}
ASSISTANT:"""

tokens = tokenizer(
    prompt_template.format(prompt="How are you today?"),
    return_tensors='pt'
).input_ids.cuda()

# Generate output
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512
)