"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d71ecad8cde6edaa231253bfbf10d5f231d72203"
Commit 219ccb33 authored by Casper Hansen

Unify using variable safetensors

parent bb455d7c
@@ -37,11 +37,11 @@ class AutoAWQForCausalLM:
     @classmethod
     def from_quantized(self, quant_path, quant_filename='pytorch_model.bin', max_new_tokens=None,
                        device='balanced', trust_remote_code=True, fuse_layers=True,
-                       batch_size=1, use_safetensors=False) -> BaseAWQForCausalLM:
+                       batch_size=1, safetensors=False) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
         model_type = check_and_get_model_type(quant_path, trust_remote_code)

         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
             quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code,
-            fuse_layers=fuse_layers, safetensors=use_safetensors
+            fuse_layers=fuse_layers, safetensors=safetensors
         )
\ No newline at end of file
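With this change the flag is exposed under a single name, safetensors, on the loading path. A minimal usage sketch, assuming the class is imported from the awq package as in the project's README; the checkpoint path and filename below are placeholders:

from awq import AutoAWQForCausalLM

# Placeholder path to an AWQ-quantized model saved in safetensors format.
quant_path = "vicuna-7b-awq"

# After this commit the keyword is `safetensors` (previously `use_safetensors`).
model = AutoAWQForCausalLM.from_quantized(
    quant_path,
    quant_filename="model.safetensors",
    safetensors=True,   # load .safetensors shards instead of pytorch_model.bin
    fuse_layers=True,
)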
@@ -217,7 +217,7 @@ class BaseAWQForCausalLM(nn.Module):
         return awq_results

-    def save_quantized(self, save_dir, use_safetensors=False, shard_size="10GB"):
+    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
         def _save_files(save_dir, model_name='', search_result=None):
             class EmptyModule(nn.Module):
                 def __init__(self): super(EmptyModule, self).__init__()
@@ -233,7 +233,7 @@ class BaseAWQForCausalLM(nn.Module):
                 torch.save(search_result, f'{save_dir}/{model_name}')
             else:
                 # model_name has no extension, add it when saving state_dict
-                model_name = 'model.safetensors' if use_safetensors else 'pytorch_model.bin'
+                model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

                 # shard checkpoint into chunks (10GB default)
                 shards, index = shard_checkpoint(
@@ -243,7 +243,7 @@ class BaseAWQForCausalLM(nn.Module):
                 )

                 for shard_file, shard in shards.items():
-                    if use_safetensors:
+                    if safetensors:
                         # safetensors must be in the same memory, so we duplicate and use contiguous memory
                         shard = {k: v.clone().contiguous() for k, v in shard.items()}
                         save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
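On the saving side, the same renamed keyword is used. A matching sketch, assuming model is an already-quantized BaseAWQForCausalLM instance and the output directory is a placeholder:

# Writes sharded model.safetensors files instead of pytorch_model.bin shards.
model.save_quantized("vicuna-7b-awq", safetensors=True, shard_size="10GB")

The clone().contiguous() copy in the hunk above reflects that safetensors' save_file does not accept non-contiguous tensors or tensors that share storage, so each shard is duplicated into contiguous memory before writing.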