"example/git@developer.sourcefind.cn:OpenDAS/fastllm.git" did not exist on "aefd9f11273430b5491a31f3966fd5149ec86ef0"
Commit 92579e9b authored by Casper Hansen's avatar Casper Hansen
Browse files

Support safetensors in benchmark

parent d5bb4ec8
...@@ -39,11 +39,12 @@ def generate(model, input_ids, n_generate): ...@@ -39,11 +39,12 @@ def generate(model, input_ids, n_generate):
return context_time, generate_time return context_time, generate_time
def run_round(model_path, quant_file, n_generate, input_ids, batch_size): def run_round(model_path, quant_file, n_generate, input_ids, batch_size, safetensors):
print(f" -- Loading model...") print(f" -- Loading model...")
model = AutoAWQForCausalLM.from_quantized( model = AutoAWQForCausalLM.from_quantized(
model_path, quant_file, fuse_layers=True, model_path, quant_file, fuse_layers=True,
max_new_tokens=n_generate, batch_size=batch_size max_new_tokens=n_generate, batch_size=batch_size,
safetensors=safetensors
) )
print(f" -- Warming up...") print(f" -- Warming up...")
...@@ -108,7 +109,8 @@ def main(args): ...@@ -108,7 +109,8 @@ def main(args):
args.quant_file, args.quant_file,
settings["n_generate"], settings["n_generate"],
input_ids, input_ids,
args.batch_size args.batch_size,
args.safetensors
) )
all_stats.append(stats) all_stats.append(stats)
...@@ -126,7 +128,8 @@ if __name__ == "__main__": ...@@ -126,7 +128,8 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="casperhansen/vicuna-7b-v1.5-awq", help="path to the model") parser.add_argument("--model_path", type=str, default="casperhansen/vicuna-7b-v1.5-awq", help="path to the model")
parser.add_argument("--quant_file", type=str, default="awq_model_w4_g128.pt", help="weights filename") parser.add_argument("--quant_file", type=str, default="awq_model_w4_g128.pt", help="weights filename")
parser.add_argument("--batch_size", type=int, default=1, help="weights filename") parser.add_argument("--batch_size", type=int, default=1, help="Batch size for cache and generation")
parser.add_argument("--safetensors", default=False, action="store_true", help="Use for enabling safetensors")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment