Commit 435b3b4b authored by Casper Hansen

Default to empty string for model file

parent 720a1fce
@@ -35,7 +35,7 @@ class AutoAWQForCausalLM:
     )
 
     @classmethod
-    def from_quantized(self, quant_path, quant_filename='pytorch_model.bin', max_new_tokens=None,
+    def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
                        device='balanced', trust_remote_code=True, fuse_layers=True,
                        batch_size=1, safetensors=False) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
@@ -284,7 +284,7 @@ class BaseAWQForCausalLM(nn.Module):
     )
 
     @classmethod
-    def from_quantized(self, model_path, model_type, model_filename='pytorch_model.bin',
+    def from_quantized(self, model_path, model_type, model_filename='',
                        max_new_tokens=None, device='balanced', torch_dtype=torch.float16,
                        trust_remote_code=True, safetensors=False, is_quantized=True,
                        fuse_layers=False, version='GEMM'):
@@ -298,7 +298,10 @@ class BaseAWQForCausalLM(nn.Module):
             model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
 
-        model_weights_path = model_path + f'/{model_filename}'
+        if model_filename != '':
+            model_weights_path = model_path + f'/{model_filename}'
+        else:
+            model_weights_path = model_path
 
         # [STEP 2] Load config and set sequence length
         # TODO: Create BaseAWQConfig class
@@ -343,7 +346,7 @@ class BaseAWQForCausalLM(nn.Module):
         if is_quantized:
             load_checkpoint_in_model(
                 model,
-                checkpoint=model_path if safetensors else model_weights_path,
+                checkpoint=model_weights_path,
                 device_map=device_map
             )
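
For context, a minimal sketch of the path resolution this commit introduces. The helper name resolve_weights_path is hypothetical and not part of AutoAWQ; it only illustrates why an empty-string default lets the checkpoint argument fall back to the model directory, which accelerate's load_checkpoint_in_model accepts in place of a single weights file:

    import os

    def resolve_weights_path(model_path: str, model_filename: str = '') -> str:
        # Hypothetical helper mirroring the logic added in this commit.
        # With the new default of '', the checkpoint path falls back to
        # the model directory itself, so the loader can discover sharded
        # or safetensors weights instead of assuming 'pytorch_model.bin'.
        if model_filename != '':
            return model_path + f'/{model_filename}'
        return model_path

    # Explicit filename: behaves as before.
    assert resolve_weights_path('/models/awq', 'model.safetensors') == '/models/awq/model.safetensors'
    # New default: the directory itself is handed to load_checkpoint_in_model.
    assert resolve_weights_path('/models/awq') == '/models/awq'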