import gc
import logging
import os

import torch


def auto_parallel(args):
    """Pick a GPU count from the model name and restrict CUDA_VISIBLE_DEVICES.

    The size token is the text after the last ``-`` in ``args.model_path``
    (e.g. ``"llama-30b"`` -> ``"30b"``). A token ending in ``"m"`` is treated
    as size 1; otherwise the final character is dropped and the rest is
    parsed as a float. Sizes below 20 use 1 GPU, sizes below 50 use 4 GPUs,
    and anything larger uses 8 GPUs.

    Side effects:
        * ``args.parallel`` is set to ``True`` when more than one GPU is
          selected, ``False`` otherwise.
        * ``os.environ["CUDA_VISIBLE_DEVICES"]`` is overwritten with the
          first ``n_gpu`` entries of the candidate device list.

    Args:
        args: Object with a string ``model_path`` attribute; it receives a
            ``parallel`` attribute.

    Returns:
        The full candidate device list *before* truncation: a list of
        strings when ``CUDA_VISIBLE_DEVICES`` was already set, otherwise
        ``[0, 1, ..., 7]`` as integers.

    Raises:
        ValueError: If the size token (minus its last character) is not a
            valid float.
    """
    model_size = args.model_path.split("-")[-1]
    if model_size.endswith("m"):
        # Sub-billion ("...m") models are all mapped to the smallest size.
        model_gb = 1
    else:
        # Drop the trailing unit character (e.g. the "b" in "30b").
        model_gb = float(model_size[:-1])

    if model_gb < 20:
        n_gpu = 1
    elif model_gb < 50:
        n_gpu = 4
    else:
        n_gpu = 8
    args.parallel = n_gpu > 1

    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if isinstance(cuda_visible_devices, str):
        cuda_visible_devices = cuda_visible_devices.split(",")
    else:
        # Variable not set: fall back to device indices 0..7.
        cuda_visible_devices = list(range(8))

    # Slicing silently yields fewer devices if fewer than n_gpu are listed.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        str(dev) for dev in cuda_visible_devices[:n_gpu]
    )
    # The "%s" placeholder is required: logging applies ``msg % args`` lazily,
    # and a message without a placeholder fails to format once DEBUG is on.
    logging.debug("CUDA_VISIBLE_DEVICES: %s", os.environ["CUDA_VISIBLE_DEVICES"])
    return cuda_visible_devices