Commit ed618bb0 authored by Casper Hansen

Simplify saving model

parent 80996d1d
@@ -55,55 +55,46 @@ class BaseAWQForCausalLM(nn.Module):
         pass
 
     def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
-        def _save_files(save_dir, model_name='', search_result=None):
-            class EmptyModule(nn.Module):
-                def __init__(self): super(EmptyModule, self).__init__()
-                def forward(self, x): return x
-
-            # Save model files with empty state dict
-            self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
-
-            # Remove empty state dict
-            os.remove(f'{save_dir}/pytorch_model.bin')
-
-            if search_result is not None:
-                torch.save(search_result, f'{save_dir}/{model_name}')
-            else:
-                # model_name has no extension, add it when saving state_dict
-                model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'
-
-                # shard checkpoint into chunks (10GB default)
-                shards, index = shard_checkpoint(
-                    self.model.state_dict(),
-                    max_shard_size=shard_size,
-                    weights_name=model_name
-                )
-
-                for shard_file, shard in shards.items():
-                    if safetensors:
-                        # safetensors must be in the same memory, so we duplicate and use contiguous memory
-                        shard = {k: v.clone().contiguous() for k, v in shard.items()}
-                        save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
-                    else:
-                        torch.save(shard, os.path.join(save_dir, shard_file))
-
-                # save shard index
-                if index is not None:
-                    with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
-                        file.write(json.dumps(index, indent=4))
-
-            # Save config
-            with open(f'{save_dir}/quant_config.json', 'w+') as file:
-                file.write(json.dumps(self.quant_config, indent=4))
-
-        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
-
-        # Save model
-        if self.search_result is None or self.is_quantized:
-            _save_files(save_dir, '', search_result=None)
-        else:
-            model_name = 'awq_model_search_result.pt'
-            _save_files(save_dir, model_name, self.search_result)
+        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
+
+        # Save model
+        class EmptyModule(nn.Module):
+            def __init__(self): super(EmptyModule, self).__init__()
+            def forward(self, x): return x
+
+        # Save model files with empty state dict
+        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
+
+        # Remove empty state dict
+        os.remove(f'{save_dir}/pytorch_model.bin')
+
+        # model_name has no extension, add it when saving state_dict
+        model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'
+
+        # shard checkpoint into chunks (10GB default)
+        shards, index = shard_checkpoint(
+            self.model.state_dict(),
+            max_shard_size=shard_size,
+            weights_name=model_name
+        )
+
+        for shard_file, shard in shards.items():
+            if safetensors:
+                # safetensors must be in the same memory, so we duplicate and use contiguous memory
+                shard = {k: v.clone().contiguous() for k, v in shard.items()}
+                save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
+            else:
+                torch.save(shard, os.path.join(save_dir, shard_file))
+
+        # save shard index
+        if index is not None:
+            with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
+                file.write(json.dumps(index, indent=4))
+
+        # Save config
+        with open(f'{save_dir}/quant_config.json', 'w+') as file:
+            file.write(json.dumps(self.quant_config, indent=4))
 
     @classmethod
     def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
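Note on the change: the new version inlines the former _save_files helper and drops the search_result branch, so save_quantized always writes the (possibly sharded) weights plus quant_config.json. The sketch below is a minimal, standalone illustration of that sharded-save path and is not part of the commit: the save_sharded helper name, the toy nn.Linear model, and the tiny shard size are made up for the example, and shard_checkpoint / save_file are assumed to be importable from transformers and safetensors, as they are in the surrounding file.

# Minimal sketch (assumptions noted above): replicate the sharded-save path on a toy model.
import os
import json
import torch
import torch.nn as nn
from transformers.modeling_utils import shard_checkpoint
from safetensors.torch import save_file

def save_sharded(model: nn.Module, save_dir: str, safetensors: bool = False, shard_size="10GB"):
    # Strip a trailing slash, as save_quantized does.
    save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
    os.makedirs(save_dir, exist_ok=True)

    # The weights file name also determines the shard file names.
    model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

    # Split the state dict into chunks no larger than shard_size.
    shards, index = shard_checkpoint(
        model.state_dict(),
        max_shard_size=shard_size,
        weights_name=model_name
    )

    for shard_file, shard in shards.items():
        if safetensors:
            # safetensors needs contiguous tensors that own their memory, so copy first.
            shard = {k: v.clone().contiguous() for k, v in shard.items()}
            save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(save_dir, shard_file))

    # The index (weight name -> shard file) is only produced when the checkpoint was split.
    if index is not None:
        with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
            file.write(json.dumps(index, indent=4))

# Example: a ~1 KB shard limit forces even this toy model into multiple shards.
save_sharded(nn.Linear(16, 16), './tmp-awq-save', shard_size=1000)

With safetensors=True the same call writes model.safetensors shards instead; the clone().contiguous() copy is what keeps save_file from rejecting tensors that are views or share storage.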