"vscode:/vscode.git/clone" did not exist on "7492fae4c2cd16fb2783dce7e7583d7245cfbe92"
Commit 219ccb33 authored by Casper Hansen's avatar Casper Hansen
Browse files

Unify using variable safetensors

parent bb455d7c
......@@ -37,11 +37,11 @@ class AutoAWQForCausalLM:
@classmethod
def from_quantized(self, quant_path, quant_filename='pytorch_model.bin', max_new_tokens=None,
device='balanced', trust_remote_code=True, fuse_layers=True,
batch_size=1, use_safetensors=False) -> BaseAWQForCausalLM:
batch_size=1, safetensors=False) -> BaseAWQForCausalLM:
os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
model_type = check_and_get_model_type(quant_path, trust_remote_code)
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code,
fuse_layers=fuse_layers, safetensors=use_safetensors
fuse_layers=fuse_layers, safetensors=safetensors
)
\ No newline at end of file
......@@ -217,7 +217,7 @@ class BaseAWQForCausalLM(nn.Module):
return awq_results
def save_quantized(self, save_dir, use_safetensors=False, shard_size="10GB"):
def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
def _save_files(save_dir, model_name='', search_result=None):
class EmptyModule(nn.Module):
def __init__(self): super(EmptyModule, self).__init__()
......@@ -233,7 +233,7 @@ class BaseAWQForCausalLM(nn.Module):
torch.save(search_result, f'{save_dir}/{model_name}')
else:
# model_name has no extension, add it when saving state_dict
model_name = 'model.safetensors' if use_safetensors else 'pytorch_model.bin'
model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'
# shard checkpoint into chunks (10GB default)
shards, index = shard_checkpoint(
......@@ -243,7 +243,7 @@ class BaseAWQForCausalLM(nn.Module):
)
for shard_file, shard in shards.items():
if use_safetensors:
if safetensors:
# safetensors must be in the same memory, so we duplicate and use contiguous memory
shard = {k: v.clone().contiguous() for k, v in shard.items()}
save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment