Commit e04d0ec7 authored by Abhinav Kulkarni

[Minor] Added model dispatch to GPU logic

parent df0c600c
@@ -84,19 +84,12 @@ def build_model_and_enc(model_path):
                                                  torch_dtype=torch.float16, trust_remote_code=True)
         real_quantize_model_weight(
             model, w_bit=args.w_bit, q_config=q_config, init_only=True)
-        # Passing empty max_memory={} causes error
-        kwargs = {"max_memory": max_memory} if len(max_memory) else {}
         model = load_checkpoint_and_dispatch(
-            model,
-            checkpoint=args.load_quant,
-            device_map="balanced",
+            model, args.load_quant, device_map="balanced",
             # TODO: can we remove this?
             no_split_module_classes=[
-                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
-            **kwargs
+                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"]
         )
     else:  # fp16 to quantized
         args.run_awq &= not args.load_awq  # if load_awq, no need to run awq
         kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
...
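For context on the lines this commit removes: per the deleted comment, passing an empty max_memory={} to accelerate's load_checkpoint_and_dispatch causes an error, so the dropped code forwarded max_memory only when the caller had actually supplied per-device limits. Below is a minimal, self-contained sketch of that guarded-kwargs pattern; the model name, checkpoint path, and memory caps are illustrative placeholders, not values from this repository.

# Minimal sketch (not this repository's code): build an empty-weight model and
# dispatch a checkpoint across devices, forwarding max_memory only when the
# caller supplied per-device limits. Model name, checkpoint path, and memory
# caps are illustrative placeholders.
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("facebook/opt-1.3b")
with init_empty_weights():
    # Instantiate the module tree on the meta device; no real weights allocated.
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

# Hypothetical per-device caps, e.g. parsed from CLI flags; an empty dict
# means "let accelerate balance the model on its own".
max_memory = {0: "12GiB", "cpu": "48GiB"}

# Per the comment removed by this commit, passing an empty max_memory={}
# causes an error, so forward the argument only when it is non-empty.
kwargs = {"max_memory": max_memory} if len(max_memory) else {}

model = load_checkpoint_and_dispatch(
    model,
    checkpoint="/path/to/checkpoint",  # placeholder path
    device_map="balanced",
    no_split_module_classes=["OPTDecoderLayer"],
    **kwargs,
)

Keys in max_memory are device identifiers (integer GPU indices or "cpu") mapped to memory caps, which is how dispatch can be steered toward or away from particular GPUs.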