Commit a293e16f authored by Jiaming Tang

[Minor] fix loading awq checkpoint issue

parent 36913edb
@@ -87,7 +87,13 @@ def build_model_and_enc(model_path):
     model = AutoModelForCausalLM.from_pretrained(
         model_path, config=config, trust_remote_code=True, **kwargs)

-    if args.run_awq:
+    if args.load_awq:
+        print("Loading pre-computed AWQ results from", args.load_awq)
+        awq_results = torch.load(args.load_awq, map_location="cpu")
+        apply_awq(model, awq_results)
+    elif args.run_awq:
         awq_results = run_awq(
             model, enc,
             w_bit=args.w_bit, q_config=q_config,
@@ -97,11 +103,6 @@ def build_model_and_enc(model_path):
             torch.save(awq_results, args.dump_awq)
             print("AWQ results saved at", args.dump_awq)

-    if args.load_awq:
-        print("Loading pre-computed AWQ results from", args.load_awq)
-        awq_results = torch.load(args.load_awq, map_location="cpu")
-        apply_awq(model, awq_results)
-
     # weight quantization
     if args.w_bit is not None:
         if args.q_backend == "fake":
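For context on the fix: the checkpoint-loading path (args.load_awq) now runs before, and mutually exclusively with, the AWQ search (args.run_awq), so a pre-computed checkpoint is applied directly instead of re-running the search. Below is a minimal sketch of the resulting control flow, not the actual file: the import path is assumed to mirror the repo layout, and quantize_or_load is a hypothetical wrapper name for illustration.

    import torch

    # Assumed module path; run_awq/apply_awq are the helpers used in the diff.
    from awq.quantize.pre_quant import run_awq, apply_awq


    def quantize_or_load(model, enc, args, q_config):
        if args.load_awq:
            # A pre-computed checkpoint takes priority: apply the saved
            # AWQ results directly instead of re-running the search.
            print("Loading pre-computed AWQ results from", args.load_awq)
            awq_results = torch.load(args.load_awq, map_location="cpu")
            apply_awq(model, awq_results)
        elif args.run_awq:
            # No checkpoint supplied: run the activation-aware search,
            # and optionally dump the results so later runs can reuse them.
            awq_results = run_awq(model, enc,
                                  w_bit=args.w_bit, q_config=q_config)
            if args.dump_awq:
                torch.save(awq_results, args.dump_awq)
                print("AWQ results saved at", args.dump_awq)
        return model

Typical usage, assuming the CLI flags mirror the args.* attributes seen in the diff: run once with --run_awq --dump_awq to search and save the results, then pass --load_awq on later runs to skip the search entirely.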