Commit 92579e9b authored by Casper Hansen's avatar Casper Hansen
Browse files

Support safetensors in benchmark

parent d5bb4ec8
......@@ -39,11 +39,12 @@ def generate(model, input_ids, n_generate):
return context_time, generate_time
def run_round(model_path, quant_file, n_generate, input_ids, batch_size):
def run_round(model_path, quant_file, n_generate, input_ids, batch_size, safetensors):
print(f" -- Loading model...")
model = AutoAWQForCausalLM.from_quantized(
model_path, quant_file, fuse_layers=True,
max_new_tokens=n_generate, batch_size=batch_size
max_new_tokens=n_generate, batch_size=batch_size,
safetensors=safetensors
)
print(f" -- Warming up...")
......@@ -108,7 +109,8 @@ def main(args):
args.quant_file,
settings["n_generate"],
input_ids,
args.batch_size
args.batch_size,
args.safetensors
)
all_stats.append(stats)
......@@ -126,7 +128,8 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="casperhansen/vicuna-7b-v1.5-awq", help="path to the model")
parser.add_argument("--quant_file", type=str, default="awq_model_w4_g128.pt", help="weights filename")
parser.add_argument("--batch_size", type=int, default=1, help="weights filename")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for cache and generation")
parser.add_argument("--safetensors", default=False, action="store_true", help="Use for enabling safetensors")
args = parser.parse_args()
main(args)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment