"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "0be0410443a684fa18db4a56a831eee00b8aa060"
Unverified commit 46a0c6aa, authored by Sayak Paul, committed by GitHub

feat: cuda device_map for pipelines. (#12122)

* feat: cuda device_map for pipelines.

* up

* up

* empty

* up
parent 421ee07e
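
For context, this is the kind of call the change enables: a device string can now be passed as `device_map` alongside the existing `"balanced"` strategy. A minimal usage sketch, assuming a standard text-to-image checkpoint (the model id is only an example):

```python
import torch
from diffusers import DiffusionPipeline

# device_map="cuda" is new in this PR; previously only "balanced" was accepted.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # example checkpoint
    torch_dtype=torch.float16,
    device_map="cuda",
)
image = pipe("an astronaut riding a horse").images[0]
```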
```diff
@@ -613,6 +613,9 @@ def _assign_components_to_devices(
 def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
+    # TODO: separate out different device_map methods when it gets to it.
+    if device_map != "balanced":
+        return device_map
     # To avoid circular import problem.
     from diffusers import pipelines
```
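
The effect of the early return, sketched as standalone code (simplified names, not the library's exact implementation):

```python
def resolve_device_map(device_map):
    # Simplified stand-in for _get_final_device_map: any value other than
    # "balanced" (e.g. "cuda") is now passed through unchanged.
    if device_map != "balanced":
        return device_map
    # Existing "balanced" path, stubbed here: compute a per-component placement.
    return {"unet": 0, "text_encoder": 0, "vae": 0}

assert resolve_device_map("cuda") == "cuda"               # new string path
assert isinstance(resolve_device_map("balanced"), dict)   # balanced path
```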
...
```diff
@@ -108,7 +108,7 @@ LIBRARIES = []
 for library in LOADABLE_CLASSES:
     LIBRARIES.append(library)
 
-SUPPORTED_DEVICE_MAP = ["balanced"]
+SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
 
 logger = logging.get_logger(__name__)
```
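
What the extended allow-list evaluates to at import time, as an illustration (`get_device` is simplified here; the real helper probes more backends):

```python
import torch

def get_device():
    # Simplified version of diffusers.utils.torch_utils.get_device.
    return "cuda" if torch.cuda.is_available() else "cpu"

SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
print(SUPPORTED_DEVICE_MAP)  # ['balanced', 'cuda'] on a CUDA machine
```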
```diff
@@ -988,12 +988,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
         for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
             # 7.1 device_map shenanigans
-            if final_device_map is not None and len(final_device_map) > 0:
-                component_device = final_device_map.get(name, None)
-                if component_device is not None:
-                    current_device_map = {"": component_device}
-                else:
-                    current_device_map = None
+            if final_device_map is not None:
+                if isinstance(final_device_map, dict) and len(final_device_map) > 0:
+                    component_device = final_device_map.get(name, None)
+                    if component_device is not None:
+                        current_device_map = {"": component_device}
+                    else:
+                        current_device_map = None
+                elif isinstance(final_device_map, str):
+                    current_device_map = final_device_map
 
             # 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
             class_name = class_name[4:] if class_name.startswith("Flax") else class_name
```
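
The new branch in isolation, as a hedged sketch (the helper name is made up for illustration):

```python
def component_device_map(final_device_map, name):
    """How each component's device_map is derived during loading (simplified)."""
    if final_device_map is None:
        return None
    if isinstance(final_device_map, dict) and len(final_device_map) > 0:
        component_device = final_device_map.get(name, None)
        return {"": component_device} if component_device is not None else None
    if isinstance(final_device_map, str):
        # New path: a plain device string like "cuda" applies to every component.
        return final_device_map
    return None

assert component_device_map("cuda", "unet") == "cuda"
assert component_device_map({"unet": 0}, "unet") == {"": 0}
assert component_device_map({"unet": 0}, "vae") is None
```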
...
```diff
@@ -15,6 +15,7 @@
 PyTorch utilities: Utilities related to PyTorch
 """
 
+import functools
 from typing import List, Optional, Tuple, Union
 
 from . import logging
@@ -168,6 +169,7 @@ def get_torch_cuda_device_capability():
     return None
 
 
+@functools.lru_cache
 def get_device():
     if torch.cuda.is_available():
         return "cuda"
```
...
```diff
@@ -2339,6 +2339,29 @@ class PipelineTesterMixin:
                 f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
             )
 
+    @require_torch_accelerator
+    def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        torch.manual_seed(0)
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["generator"] = torch.manual_seed(0)
+        out = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir)
+            loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map=torch_device)
+            for component in loaded_pipe.components.values():
+                if hasattr(component, "set_default_attn_processor"):
+                    component.set_default_attn_processor()
+
+        inputs["generator"] = torch.manual_seed(0)
+        loaded_out = loaded_pipe(**inputs)[0]
+        max_diff = np.abs(to_np(out) - to_np(loaded_out)).max()
+        self.assertLess(max_diff, expected_max_difference)
+
 
 @is_staging_test
 class PipelinePushToHubTester(unittest.TestCase):
```
...