Commit 435b3b4b authored by Casper Hansen's avatar Casper Hansen
Browse files

Default to empty string for model file

parent 720a1fce
...@@ -35,7 +35,7 @@ class AutoAWQForCausalLM: ...@@ -35,7 +35,7 @@ class AutoAWQForCausalLM:
) )
@classmethod @classmethod
def from_quantized(self, quant_path, quant_filename='pytorch_model.bin', max_new_tokens=None, def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
device='balanced', trust_remote_code=True, fuse_layers=True, device='balanced', trust_remote_code=True, fuse_layers=True,
batch_size=1, safetensors=False) -> BaseAWQForCausalLM: batch_size=1, safetensors=False) -> BaseAWQForCausalLM:
os.environ["AWQ_BATCH_SIZE"] = str(batch_size) os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
......
...@@ -284,7 +284,7 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -284,7 +284,7 @@ class BaseAWQForCausalLM(nn.Module):
) )
@classmethod @classmethod
def from_quantized(self, model_path, model_type, model_filename='pytorch_model.bin', def from_quantized(self, model_path, model_type, model_filename='',
max_new_tokens=None, device='balanced', torch_dtype=torch.float16, max_new_tokens=None, device='balanced', torch_dtype=torch.float16,
trust_remote_code=True, safetensors=False, is_quantized=True, trust_remote_code=True, safetensors=False, is_quantized=True,
fuse_layers=False, version='GEMM'): fuse_layers=False, version='GEMM'):
...@@ -298,7 +298,10 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -298,7 +298,10 @@ class BaseAWQForCausalLM(nn.Module):
model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns) model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
if model_filename != '':
model_weights_path = model_path + f'/{model_filename}' model_weights_path = model_path + f'/{model_filename}'
else:
model_weights_path = model_path
# [STEP 2] Load config and set sequence length # [STEP 2] Load config and set sequence length
# TODO: Create BaseAWQConfig class # TODO: Create BaseAWQConfig class
...@@ -343,7 +346,7 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -343,7 +346,7 @@ class BaseAWQForCausalLM(nn.Module):
if is_quantized: if is_quantized:
load_checkpoint_in_model( load_checkpoint_in_model(
model, model,
checkpoint=model_path if safetensors else model_weights_path, checkpoint=model_weights_path,
device_map=device_map device_map=device_map
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment