Unverified Commit 4155ec7f authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

pass device_map other than auto for parallelize (#2457)

* pass device_map other than auto for parallelize
parent 901053f6
...@@ -346,9 +346,9 @@ class HFLM(TemplateLM): ...@@ -346,9 +346,9 @@ class HFLM(TemplateLM):
== (self.accelerator.process_index % num_local_processes) == (self.accelerator.process_index % num_local_processes)
} }
args["max_memory"] = max_memory_per_gpu_map args["max_memory"] = max_memory_per_gpu_map
args["device_map"] = "auto" args["device_map"] = "auto" if device_map is None else device_map
eval_logger.info( eval_logger.info(
f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to 'auto'" f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to {args.get('device_map')}"
) )
if max_cpu_memory is not None: if max_cpu_memory is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment