Unverified commit af2a2376, authored by Stas Bekman, committed by GitHub
Browse files

[deepspeed] partial ZeRO-3 support (#3076)



* [deepspeed] partial ZeRO-3 support

* cleanup

* improve deepspeed fixes

* Improve

* make style

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent d71db894
@@ -29,6 +29,7 @@ import torch.utils.checkpoint
import transformers import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder from huggingface_hub import create_repo, upload_folder
@@ -36,6 +37,7 @@ from packaging import version
from torchvision import transforms from torchvision import transforms
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
from transformers.utils import ContextManagers
import diffusers import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
@@ -464,10 +466,34 @@ def main():
# Load the tokenizer for the pretrained pipeline. (Reconstructed: the scraped diff
# doubled every line side-by-side, leaving the statement syntactically invalid.)
tokenizer = CLIPTokenizer.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
def deepspeed_zero_init_disabled_context_manager():
    """
    Return a list of context managers that disable DeepSpeed's ``zero.Init``, or an
    empty list when DeepSpeed is not in use.

    Wrapping ``from_pretrained`` calls in these contexts keeps frozen models from being
    ZeRO-3 partitioned at load time.

    Returns:
        list: ``[deepspeed_plugin.zero3_init_context_manager(enable=False)]`` when a
        DeepSpeed plugin is active, otherwise ``[]``.
    """
    # Read the plugin off the AcceleratorState: the bare state object has no
    # `zero3_init_context_manager`, and assigning the (always-truthy) state itself
    # would also defeat the `is None` guard below when no DeepSpeed plugin is set.
    deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
    if deepspeed_plugin is None:
        return []

    return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
# Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
# For this to work properly all models must be run through `accelerate.prepare`. But accelerate
# will try to assign the same optimizer with the same weights to all models during
# `deepspeed.initialize`, which of course doesn't work.
#
# For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
# frozen models from being partitioned during `zero.Init` which gets called during
# `from_pretrained`. So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
# across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
    )

# The UNet is loaded OUTSIDE the disabled-zero.Init context on purpose: it is the one
# trainable model, so it should still be partitioned by ZeRO-3 at load time.
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
)
......
import contextlib
import copy import copy
import os import os
import random import random
@@ -6,7 +7,11 @@ from typing import Any, Dict, Iterable, Optional, Union
import numpy as np import numpy as np
import torch import torch
from .utils import deprecate from .utils import deprecate, is_transformers_available
if is_transformers_available():
import transformers
def enable_full_determinism(seed: int): def enable_full_determinism(seed: int):
@@ -197,7 +202,15 @@ class EMAModel:
self.cur_decay_value = decay self.cur_decay_value = decay
one_minus_decay = 1 - decay one_minus_decay = 1 - decay
context_manager = contextlib.nullcontext
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
import deepspeed
for s_param, param in zip(self.shadow_params, parameters): for s_param, param in zip(self.shadow_params, parameters):
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
with context_manager():
if param.requires_grad: if param.requires_grad:
s_param.sub_(one_minus_decay * (s_param - param)) s_param.sub_(one_minus_decay * (s_param - param))
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment