Commit 6371c3a0 authored by Jiaming Tang

[Minor] Merge model initialization

parent e04d0ec7
@@ -92,22 +92,14 @@ def build_model_and_enc(model_path):
         )
     else:  # fp16 to quantized
         args.run_awq &= not args.load_awq  # if load_awq, no need to run awq
+        # Init model on CPU:
         kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path, config=config, trust_remote_code=True, **kwargs)
         if args.run_awq:
             assert args.dump_awq, "Please save the awq results with --dump_awq"
-            # Init model on CPU
-            def skip(*args, **kwargs):
-                pass
-            torch.nn.init.kaiming_normal_ = skip
-            torch.nn.init.kaiming_uniform_ = skip
-            torch.nn.init.uniform_ = skip
-            torch.nn.init.normal_ = skip
-            model = AutoModelForCausalLM.from_pretrained(
-                model_path, config=config, trust_remote_code=True, **kwargs)
             awq_results = run_awq(
                 model, enc,
                 w_bit=args.w_bit, q_config=q_config,
@@ -121,11 +113,6 @@ def build_model_and_enc(model_path):
             print("AWQ results saved at", args.dump_awq)
             exit(0)
-        else:
-            # Inference with fake quant
-            # Init model on CPU:
-            model = AutoModelForCausalLM.from_pretrained(
-                model_path, config=config, trust_remote_code=True, **kwargs)
         if args.load_awq:
             print("Loading pre-computed AWQ results from", args.load_awq)
......
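
For context: the change consolidates the two duplicated AutoModelForCausalLM.from_pretrained calls (one per branch) into a single CPU-side fp16 initialization that both the AWQ-search path and the fake-quant inference path reuse, and it drops the skip monkeypatches of torch.nn.init, presumably because low_cpu_mem_usage=True already skips random weight initialization for weights that are about to be loaded from the checkpoint. Below is a minimal sketch of the merged flow, not the file itself: the wrapper name build_fp16_model and the run_awq import path are assumptions for illustration, and the steps elided by "..." in the diff stay elided here as well.

import torch
from transformers import AutoModelForCausalLM
# run_awq lives elsewhere in the llm-awq codebase; this import path is an
# assumption for illustration.
from awq.quantize.pre_quant import run_awq


def build_fp16_model(model_path, config, enc, args, q_config):
    """Hypothetical wrapper mirroring the merged branch of build_model_and_enc."""
    # Init model on CPU: one from_pretrained call shared by both paths
    # (previously duplicated in the run_awq and fake-quant branches).
    kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
    model = AutoModelForCausalLM.from_pretrained(
        model_path, config=config, trust_remote_code=True, **kwargs)

    if args.run_awq:
        # Search for AWQ scales on the freshly loaded model, then stop; the
        # code that actually dumps awq_results is elided in the shown hunks.
        assert args.dump_awq, "Please save the awq results with --dump_awq"
        awq_results = run_awq(
            model, enc,
            w_bit=args.w_bit, q_config=q_config,
        )
        print("AWQ results saved at", args.dump_awq)
        exit(0)

    if args.load_awq:
        # Reuse pre-computed AWQ results instead of re-running the search;
        # the loading/applying logic sits outside the shown hunks.
        print("Loading pre-computed AWQ results from", args.load_awq)

    return model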