Unverified commit 4155ec7f, authored by Baber Abbasi, committed by GitHub

pass device_map other than auto for parallelize (#2457)

* pass device_map other than auto for parallelize

Parent: 901053f6
@@ -346,9 +346,9 @@ class HFLM(TemplateLM):
                     == (self.accelerator.process_index % num_local_processes)
                 }
                 args["max_memory"] = max_memory_per_gpu_map
-                args["device_map"] = "auto"
+                args["device_map"] = "auto" if device_map is None else device_map
                 eval_logger.info(
-                    f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to 'auto'"
+                    f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to {args.get('device_map')}"
                 )
                 if max_cpu_memory is not None:
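In effect, the hard-coded "auto" becomes a fallback: a caller-supplied device_map is now forwarded to the accelerate/transformers loading arguments, and "auto" is used only when nothing is passed. A minimal sketch of that fallback logic follows; the helper name `build_accelerate_args` and its signature are hypothetical, only the device_map conditional mirrors the diff.

```python
# Minimal sketch of the fallback introduced above; `build_accelerate_args`
# is a hypothetical stand-in for the harness's internal argument assembly.
from typing import Optional, Union


def build_accelerate_args(
    device_map: Optional[Union[str, dict]] = None,
    max_memory_per_gpu_map: Optional[dict] = None,
) -> dict:
    """Assemble kwargs for model loading when model parallelism is enabled."""
    args = {}
    if max_memory_per_gpu_map is not None:
        args["max_memory"] = max_memory_per_gpu_map
    # Previously hard-coded to "auto"; now a caller-supplied device_map
    # (e.g. "balanced_low_0" or an explicit module->device dict) takes
    # precedence, with "auto" kept as the default when nothing is passed.
    args["device_map"] = "auto" if device_map is None else device_map
    return args


print(build_accelerate_args())                             # {'device_map': 'auto'}
print(build_accelerate_args(device_map="balanced_low_0"))  # {'device_map': 'balanced_low_0'}
```

With parallelize enabled, this should let users request placement strategies such as "balanced" or "balanced_low_0", or a manual module-to-device mapping, instead of always receiving "auto".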