Commit 4e7ada89 authored by Abhinav Kulkarni

[Minor] Added max-memory command line parameter

parent d32095ab
@@ -20,6 +20,12 @@ parser.add_argument('--num_fewshot', type=int, default=0)
 # model config
 parser.add_argument('--parallel', action='store_true',
                     help="enable model parallelism")
+# max memory to offload larger models to CPU
+parser.add_argument('--max_memory', type=str, nargs='*',
+                    help="List of device_id:max_memory pairs to be parsed into a dictionary; " \
+                         "Example: 0:10GiB 1:10GiB cpu:20GiB; " \
+                         "more details here: " \
+                         "https://huggingface.co/docs/accelerate/usage_guides/big_modeling")
 parser.add_argument('--auto_parallel', action='store_true',
                     help="automatically set parallel and batch_size")
 # quantization config
@@ -43,6 +49,9 @@ parser.add_argument('--load_awq', type=str, default=None,
                     help="load the awq search results")
 args = parser.parse_args()
+
+max_memory = [v.split(':') for v in (args.max_memory or [])]
+max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}
 if args.auto_parallel:
     gpu_list = auto_parallel(args)
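
For reference, a minimal standalone sketch of the parsing added above (the device:limit values are illustrative, not from the diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max_memory', type=str, nargs='*')
args = parser.parse_args(['--max_memory', '0:10GiB', '1:10GiB', 'cpu:20GiB'])

# Split each "device:limit" pair once; GPU ids become int keys, "cpu" stays a string.
max_memory = [v.split(':') for v in (args.max_memory or [])]
max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}
print(max_memory)  # {0: '10GiB', 1: '10GiB', 'cpu': '20GiB'}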
@@ -115,7 +124,7 @@ def build_model_and_enc(model_path):
     else:
         # Inference with fake quant
         # Init model on CPU:
-        kwargs = {"torch_dtype": torch.float16}
+        kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
         model = AutoModelForCausalLM.from_pretrained(
             model_path, config=config, trust_remote_code=True, **kwargs)
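
The added low_cpu_mem_usage=True tells transformers to build the model lazily instead of first allocating a fully initialized copy in host RAM, lowering peak CPU memory while loading. A minimal sketch of the same call, assuming a placeholder checkpoint name:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",        # placeholder checkpoint, not from the diff
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,     # avoid the full in-memory init before loading weights
)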
@@ -151,8 +160,9 @@ def build_model_and_enc(model_path):
         kwargs = {
             "torch_dtype": torch.float16,
             "device_map": "auto",
-            "max_memory": {0: "8GiB", "cpu": "99GiB"}
         }
+        if len(max_memory):
+            kwargs["max_memory"] = max_memory
         model = AutoModelForCausalLM.from_pretrained(
             model_path, config=config, state_dict=model.state_dict(), trust_remote_code=True, **kwargs)
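
The hardcoded {0: "8GiB", "cpu": "99GiB"} budget is dropped; max_memory is now passed only when the user supplied --max_memory, and Accelerate's "auto" device map caps each listed device, offloading whatever does not fit to CPU. A minimal sketch of the resulting call, again with a placeholder checkpoint:

import torch
from transformers import AutoModelForCausalLM

max_memory = {0: "10GiB", 1: "10GiB", "cpu": "20GiB"}  # e.g. parsed from --max_memory
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",        # placeholder checkpoint, not from the diff
    torch_dtype=torch.float16,
    device_map="auto",          # let Accelerate place layers across the listed devices
    max_memory=max_memory,      # per-device caps; overflow is offloaded to CPU
)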