"vscode:/vscode.git/clone" did not exist on "29564ad62b6e9abb564d4005bbe3fba65179fc89"
Commit 04e177ad authored by Casper Hansen's avatar Casper Hansen
Browse files

Support safetensors

parent d73d13b2
......@@ -23,11 +23,11 @@ class AutoAWQForCausalLM:
'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')
@classmethod
def from_pretrained(self, model_path, trust_remote_code=True) -> BaseAWQForCausalLM:
def from_pretrained(self, model_path, trust_remote_code=True, safetensors=False) -> BaseAWQForCausalLM:
model_type = check_and_get_model_type(model_path, trust_remote_code)
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
model_path, model_type, trust_remote_code=trust_remote_code
model_path, model_type, trust_remote_code=trust_remote_code, safetensors=safetensors
)
@classmethod
......
......@@ -48,6 +48,7 @@ class BaseAWQForCausalLM(nn.Module):
if run_quant:
self._awq_quant()
self.is_quantized = True
def _awq_quant(self):
......@@ -224,7 +225,7 @@ class BaseAWQForCausalLM(nn.Module):
save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
# Save model
if self.search_result is None:
if self.search_result is None and not self.is_quantized:
model_name = f'awq_model_w{self.quant_config["w_bit"]}_g{self.quant_config["q_group_size"]}.pt'
_save_files(save_dir, model_name, self.model.state_dict())
else:
......@@ -233,7 +234,7 @@ class BaseAWQForCausalLM(nn.Module):
@classmethod
def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
trust_remote_code=True):
trust_remote_code=True, safetensors=False):
return self.from_quantized(
model_path,
model_type,
......@@ -241,6 +242,7 @@ class BaseAWQForCausalLM(nn.Module):
device='balanced',
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
safetensors=safetensors,
is_quantized=False
)
......@@ -300,7 +302,7 @@ class BaseAWQForCausalLM(nn.Module):
# Load model weights
model = AutoModelForCausalLM.from_pretrained(
model_filename, device_map=device_map, offload_folder="offload", offload_state_dict=True, torch_dtype=torch_dtype
model_filename, device_map=device_map, offload_folder="offload", offload_state_dict=True, torch_dtype=torch_dtype, use_safetensors=safetensors
)
model.eval()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment