Commit 04e177ad authored by Casper Hansen

Support safetensors

parent d73d13b2
@@ -23,11 +23,11 @@ class AutoAWQForCausalLM:
             'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')
 
     @classmethod
-    def from_pretrained(self, model_path, trust_remote_code=True) -> BaseAWQForCausalLM:
+    def from_pretrained(self, model_path, trust_remote_code=True, safetensors=False) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(model_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
-            model_path, model_type, trust_remote_code=trust_remote_code
+            model_path, model_type, trust_remote_code=trust_remote_code, safetensors=safetensors
         )
 
     @classmethod
...
@@ -48,6 +48,7 @@ class BaseAWQForCausalLM(nn.Module):
         if run_quant:
             self._awq_quant()
+            self.is_quantized = True
 
     def _awq_quant(self):
@@ -224,7 +225,7 @@ class BaseAWQForCausalLM(nn.Module):
         save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
 
         # Save model
-        if self.search_result is None:
+        if self.search_result is None and not self.is_quantized:
             model_name = f'awq_model_w{self.quant_config["w_bit"]}_g{self.quant_config["q_group_size"]}.pt'
             _save_files(save_dir, model_name, self.model.state_dict())
         else:
@@ -233,7 +234,7 @@ class BaseAWQForCausalLM(nn.Module):
     @classmethod
     def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
-                        trust_remote_code=True):
+                        trust_remote_code=True, safetensors=False):
         return self.from_quantized(
             model_path,
             model_type,
@@ -241,6 +242,7 @@ class BaseAWQForCausalLM(nn.Module):
             device='balanced',
             torch_dtype=torch_dtype,
             trust_remote_code=trust_remote_code,
+            safetensors=safetensors,
             is_quantized=False
         )
@@ -300,7 +302,7 @@ class BaseAWQForCausalLM(nn.Module):
         # Load model weights
         model = AutoModelForCausalLM.from_pretrained(
-            model_filename, device_map=device_map, offload_folder="offload", offload_state_dict=True, torch_dtype=torch_dtype
+            model_filename, device_map=device_map, offload_folder="offload", offload_state_dict=True, torch_dtype=torch_dtype, use_safetensors=safetensors
         )
         model.eval()
...
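End to end, the new flag threads from AutoAWQForCausalLM.from_pretrained through BaseAWQForCausalLM.from_pretrained and from_quantized down to Transformers' use_safetensors argument, so weights are read from .safetensors shards instead of pickled .bin files. A minimal usage sketch of the new flag; the import path and model directory below are assumptions for illustration, not taken from this diff:

    from awq import AutoAWQForCausalLM  # assumed import path for this repo

    # safetensors=True is forwarded to use_safetensors= in
    # AutoModelForCausalLM.from_pretrained, per this commit.
    model = AutoAWQForCausalLM.from_pretrained(
        "path/to/model",       # placeholder: any model dir with safetensors weights
        trust_remote_code=True,
        safetensors=True,      # new flag added in this commit; defaults to False
    )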