Unverified Commit 74821781 authored by Suraj Patil, committed by GitHub

default fast model loading 🔥 (#1115)



* make accelerate hard dep

* default fast init (see the sketch after this list)

* move params to cpu when device map is None

* handle device_map=None

* handle torch < 1.9

* remove device_map="auto"

* style

* add accelerate in torch extra

* remove accelerate from extras["test"]

* raise an error if torch is available but not accelerate (also sketched below)

* update installation docs

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* improve default loading speed even further, allow disabling fast loading

* address review comments

* adapt the tests

* fix test_stable_diffusion_fast_load

* fix test_read_init

* temp fix for dummy checks

* Trigger Build

* Apply suggestions from code review
Co-authored-by: Anton Lozhkov <anton@huggingface.co>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
parent ef2ea33c
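For orientation, here is a minimal sketch of the two mechanisms the bullets above describe: the hard-dependency guard and meta-device ("fast") initialization. The `fast_load` helper and its signature are illustrative assumptions, not the actual diffusers internals; `init_empty_weights` and `set_module_tensor_to_device` are real accelerate APIs.

```python
import importlib.util

# Guard from the bullets above: once torch is installed, accelerate becomes a
# hard requirement, so fail early with an actionable message.
if importlib.util.find_spec("accelerate") is None:
    raise ImportError(
        "torch is available but accelerate is not; fast model loading requires "
        "accelerate (`pip install accelerate`)."
    )

from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device


def fast_load(model_cls, config, state_dict):
    # Build the model on the meta device: no memory is allocated and the usual
    # (slow) random weight initialization is skipped entirely. The meta device
    # needs torch >= 1.9, hence the separate "handle torch < 1.9" bullet.
    with init_empty_weights():
        model = model_cls(**config)
    # Materialize each parameter directly from the checkpoint, on CPU when no
    # device_map is given (the "move params to cpu" bullet).
    for name, tensor in state_dict.items():
        set_module_tensor_to_device(model, name, "cpu", value=tensor)
    return model
```

The keyword for opting back out of fast loading is not shown in this diff; the test name `test_stable_diffusion_fast_load` in the bullets suggests a dedicated toggle exists.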
@@ -364,11 +364,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
         )
         model_id = "CompVis/stable-diffusion-v1-4"
-        pipe = StableDiffusionInpaintPipeline.from_pretrained(
-            model_id,
-            safety_checker=None,
-            device_map="auto",
-        )
+        pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         pipe.enable_attention_slicing()
@@ -411,7 +407,6 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
             model_id,
             scheduler=lms,
             safety_checker=None,
-            device_map="auto",
         )
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -468,7 +463,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
         )
         pipe = StableDiffusionInpaintPipeline.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
+            "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
         )
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -52,13 +52,13 @@ class CheckDummiesTester(unittest.TestCase):
     def test_read_init(self):
         objects = read_init()
         # We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects
-        self.assertIn("torch", objects)
+        self.assertIn("torch_and_accelerate", objects)
         self.assertIn("torch_and_transformers", objects)
         self.assertIn("flax_and_transformers", objects)
         self.assertIn("torch_and_transformers_and_onnx", objects)
         # Likewise, we can't assert on the exact content of a key
-        self.assertIn("UNet2DModel", objects["torch"])
+        self.assertIn("UNet2DModel", objects["torch_and_accelerate"])
         self.assertIn("FlaxUNet2DConditionModel", objects["flax"])
         self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"])
         self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"])
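The new `torch_and_accelerate` key checked above comes from the dummy-object machinery: when a backend combination is missing, the library still exposes importable stand-ins that fail loudly on first use. A simplified, hand-written sketch of that pattern (the real stand-ins are auto-generated, and `requires_backends` lives in diffusers' utils):

```python
import importlib.util


def requires_backends(obj, backends):
    # Simplified stand-in for the real helper: verify each backend is
    # importable and raise a readable error naming whatever is missing.
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        raise ImportError(
            f"{type(obj).__name__} requires the following backends: {', '.join(missing)}."
        )


class UNet2DModel:
    # Dummy exposed when torch/accelerate are absent: importing it works, but
    # instantiating it reports the missing backends instead of crashing later.
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "accelerate"])
```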
@@ -128,7 +128,7 @@ class CustomPipelineTests(unittest.TestCase):
     def test_load_pipeline_from_git(self):
         clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"

-        feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id, device_map="auto")
+        feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
         clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)
         pipeline = DiffusionPipeline.from_pretrained(
@@ -138,7 +138,6 @@ class CustomPipelineTests(unittest.TestCase):
             feature_extractor=feature_extractor,
             torch_dtype=torch.float16,
             revision="fp16",
-            device_map="auto",
         )
         pipeline.enable_attention_slicing()
         pipeline = pipeline.to(torch_device)
@@ -333,9 +332,7 @@ class PipelineSlowTests(unittest.TestCase):
     def test_smart_download(self):
         model_id = "hf-internal-testing/unet-pipeline-dummy"
         with tempfile.TemporaryDirectory() as tmpdirname:
-            _ = DiffusionPipeline.from_pretrained(
-                model_id, cache_dir=tmpdirname, force_download=True, device_map="auto"
-            )
+            _ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True)
             local_repo_name = "--".join(["models"] + model_id.split("/"))
             snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots")
             snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0])
@@ -359,7 +356,10 @@ class PipelineSlowTests(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             with CaptureLogger(logger) as cap_logger:
                 DiffusionPipeline.from_pretrained(
-                    model_id, not_used=True, cache_dir=tmpdirname, force_download=True, device_map="auto"
+                    model_id,
+                    not_used=True,
+                    cache_dir=tmpdirname,
+                    force_download=True,
                 )

         assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
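The exact warning string asserted above is produced when `from_pretrained` filters its keyword arguments; a minimal sketch of that behavior, with a hypothetical helper name:

```python
import logging

logger = logging.getLogger(__name__)


def split_known_kwargs(expected_keys, kwargs):
    # Consume the keyword arguments from_pretrained understands and log the
    # leftovers verbatim, yielding "Keyword arguments {...} not recognized."
    known = {k: kwargs.pop(k) for k in list(kwargs) if k in expected_keys}
    if kwargs:
        logger.warning(f"Keyword arguments {kwargs} not recognized.")
    return known
```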
@@ -383,7 +383,7 @@ class PipelineSlowTests(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             ddpm.save_pretrained(tmpdirname)
-            new_ddpm = DDPMPipeline.from_pretrained(tmpdirname, device_map="auto")
+            new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
             new_ddpm.to(torch_device)

         generator = torch.manual_seed(0)
@@ -399,11 +399,11 @@ class PipelineSlowTests(unittest.TestCase):
         scheduler = DDPMScheduler(num_train_timesteps=10)

-        ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
+        ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler)
         ddpm = ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)

-        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
+        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
         ddpm_from_hub = ddpm_from_hub.to(torch_device)
         ddpm_from_hub.set_progress_bar_config(disable=None)
@@ -421,14 +421,12 @@ class PipelineSlowTests(unittest.TestCase):
         scheduler = DDPMScheduler(num_train_timesteps=10)

         # pass unet into DiffusionPipeline
-        unet = UNet2DModel.from_pretrained(model_path, device_map="auto")
-        ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(
-            model_path, unet=unet, scheduler=scheduler, device_map="auto"
-        )
+        unet = UNet2DModel.from_pretrained(model_path)
+        ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler)
         ddpm_from_hub_custom_model = ddpm_from_hub_custom_model.to(torch_device)
         ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)

-        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
+        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
         ddpm_from_hub = ddpm_from_hub.to(torch_device)
         ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
@@ -443,7 +441,7 @@ class PipelineSlowTests(unittest.TestCase):
     def test_output_format(self):
         model_path = "google/ddpm-cifar10-32"

-        pipe = DDIMPipeline.from_pretrained(model_path, device_map="auto")
+        pipe = DDIMPipeline.from_pretrained(model_path)
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -467,7 +465,7 @@ class PipelineSlowTests(unittest.TestCase):
     def test_ddpm_ddim_equality(self, seed):
         model_id = "google/ddpm-cifar10-32"

-        unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
+        unet = UNet2DModel.from_pretrained(model_id)
         ddpm_scheduler = DDPMScheduler()
         ddim_scheduler = DDIMScheduler()
@@ -498,7 +496,7 @@ class PipelineSlowTests(unittest.TestCase):
     def test_ddpm_ddim_equality_batched(self, seed):
         model_id = "google/ddpm-cifar10-32"

-        unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
+        unet = UNet2DModel.from_pretrained(model_id)
         ddpm_scheduler = DDPMScheduler()
         ddim_scheduler = DDIMScheduler()
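Taken together, the updated tests converge on one loading pattern, shown here as a small usage sketch (model id reused from the hunks above; fast loading needs no flag since it is now the default):

```python
import torch
from diffusers import DDIMPipeline

# Load fast on CPU, then place the pipeline explicitly instead of passing
# device_map="auto" at load time.
pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")
pipe.to("cuda" if torch.cuda.is_available() else "cpu")
image = pipe(num_inference_steps=10).images[0]
```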