"graphbolt/src/cuda/sampling_utils.hip" did not exist on "3795a006b91c94291b911f0daa261c0598d7ffd8"
Commit ed618bb0 authored by Casper Hansen

Simplify saving model

parent 80996d1d
@@ -55,7 +55,9 @@ class BaseAWQForCausalLM(nn.Module):
         pass

     def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
-        def _save_files(save_dir, model_name='', search_result=None):
+        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
+
+        # Save model
         class EmptyModule(nn.Module):
             def __init__(self): super(EmptyModule, self).__init__()
             def forward(self, x): return x
@@ -66,9 +68,6 @@ class BaseAWQForCausalLM(nn.Module):
         # Remove empty state dict
         os.remove(f'{save_dir}/pytorch_model.bin')

-        if search_result is not None:
-            torch.save(search_result, f'{save_dir}/{model_name}')
-        else:
         # model_name has no extension, add it when saving state_dict
         model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'
@@ -96,14 +95,6 @@ class BaseAWQForCausalLM(nn.Module):
         with open(f'{save_dir}/quant_config.json', 'w+') as file:
             file.write(json.dumps(self.quant_config, indent=4))

-        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
-
-        # Save model
-        if self.search_result is None or self.is_quantized:
-            _save_files(save_dir, '', search_result=None)
-        else:
-            model_name = 'awq_model_search_result.pt'
-            _save_files(save_dir, model_name, self.search_result)
-
     @classmethod
     def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
......
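
Net effect of the commit: save_quantized normalizes save_dir up front and always takes the single weight-saving path, dropping the nested _save_files helper and the search_result branch removed in the last hunk. The sketch below is a hypothetical reconstruction of that simplified flow based only on the lines visible in this diff; the function name, the state_dict handling, and the safetensors import are assumptions, not the repository's exact code, and sharding by shard_size is omitted.

import json
import os

import torch
import torch.nn as nn


def save_quantized_sketch(model: nn.Module, quant_config: dict, save_dir: str,
                          safetensors: bool = False):
    # Normalize the output directory (the line this commit moves to the top
    # of save_quantized).
    save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
    os.makedirs(save_dir, exist_ok=True)

    # Persist the quantization config alongside the weights, mirroring the
    # unchanged quant_config.json context lines in the diff.
    with open(f'{save_dir}/quant_config.json', 'w+') as file:
        file.write(json.dumps(quant_config, indent=4))

    # model_name has no extension upstream; choose it based on the format.
    model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

    if safetensors:
        # Assumed dependency: the safetensors torch helper (tensors must be
        # contiguous and unshared). The real repo may shard large models.
        from safetensors.torch import save_file
        save_file(model.state_dict(), f'{save_dir}/{model_name}')
    else:
        torch.save(model.state_dict(), f'{save_dir}/{model_name}')

With this shape there is only one way quantized weights get written, which is why the search_result save path could be deleted outright.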