Unverified Commit 2ada094b authored by drhead, committed by GitHub

Add extra performance features for EMAModel, torch._foreach operations and better support for non-blocking CPU offloading (#7685)

* Add support for _foreach operations and non-blocking to EMAModel

* default foreach to false

* add non-blocking EMA offloading to SD1.5 T2I example script

* fix whitespace

* move foreach to cli argument

* linting

* Update README.md re: EMA weight training

* correct args.foreach_ema

* add tests for foreach ema

* code quality

* add foreach to from_pretrained

* default foreach false

* fix linting

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: drhead <a@a.a>
parent f1f542bd
@@ -170,11 +170,19 @@ For our small Narutos dataset, the effects of Min-SNR weighting strategy might n
Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds.
#### Training with EMA weights
Through the `EMAModel` class, we support a convenient method of tracking an exponential moving average of model parameters. This helps to smooth out noise in model parameter updates and generally improves model performance. If enabled with the `--use_ema` argument, the final model checkpoint that is saved at the end of training will use the EMA weights.
EMA weights require an additional full-precision copy of the model parameters to be stored in memory, but otherwise add very little performance overhead. `--foreach_ema` can be used to further reduce this overhead. If you are short on VRAM and still want to use EMA weights, you can store them in CPU RAM by using the `--offload_ema` argument. This keeps the EMA weights in pinned CPU memory during the training step. Then, on every model parameter update, the EMA weights are transferred to the GPU, updated there, and sent back to the CPU. Both transfers are issued as non-blocking, so CUDA devices should be able to overlap them with other computations. With sufficient bandwidth between the host and device and a sufficiently long gap between model parameter updates, storing EMA weights in CPU RAM should have no additional performance overhead, as long as no other calls force synchronization.
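A minimal sketch of this offloading pattern outside of the example script, using the `EMAModel` options added in this commit (`foreach`, `pin_memory()`, and non-blocking `to()`). It assumes a `unet`, an `accelerator`, and a `train_dataloader` have already been set up as in the training script, and elides the forward/backward pass:

```python
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

# Track EMA weights with the faster torch._foreach_* update path.
ema_unet = EMAModel(
    unet.parameters(),
    model_cls=UNet2DConditionModel,
    model_config=unet.config,
    foreach=True,
)
# Keep the EMA shadow params in pinned CPU RAM so host<->device copies can be asynchronous.
ema_unet.pin_memory()

for batch in train_dataloader:
    ...  # forward, backward, and optimizer step elided
    if accelerator.sync_gradients:
        ema_unet.to(device="cuda", non_blocking=True)  # async copy to the GPU
        ema_unet.step(unet.parameters())               # update shadow params on the GPU
        ema_unet.to(device="cpu", non_blocking=True)   # async copy back to pinned CPU RAM
```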
#### Training with DREAM
We support training epsilon (noise) prediction models using the [DREAM (Diffusion Rectification and Estimation-Adaptive Models) strategy](https://arxiv.org/abs/2312.00210). DREAM claims to increase model fidelity for the performance cost of an extra grad-less unet `forward` step in the training loop. You can turn on DREAM training by using the `--dream_training` argument. The `--dream_detail_preservation` argument controls the detail preservation variable p and defaults to 1, the value used in the paper.
## Training with LoRA
Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
...
@@ -387,6 +387,8 @@ def parse_args():
),
)
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.")
parser.add_argument("--foreach_ema", action="store_true", help="Use faster foreach implementation of EMAModel.")
parser.add_argument(
"--non_ema_revision",
type=str,
@@ -624,7 +626,12 @@ def main():
ema_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
ema_unet = EMAModel(
ema_unet.parameters(),
model_cls=UNet2DConditionModel,
model_config=ema_unet.config,
foreach=args.foreach_ema,
)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
@@ -655,8 +662,13 @@ def main():
def load_model_hook(models, input_dir):
if args.use_ema:
load_model = EMAModel.from_pretrained(
os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach_ema
)
ema_unet.load_state_dict(load_model.state_dict())
if args.offload_ema:
ema_unet.pin_memory()
else:
ema_unet.to(accelerator.device)
del load_model
@@ -833,6 +845,9 @@ def main():
)
if args.use_ema:
if args.offload_ema:
ema_unet.pin_memory()
else:
ema_unet.to(accelerator.device)
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
@@ -1011,7 +1026,11 @@ def main():
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.use_ema:
if args.offload_ema:
ema_unet.to(device="cuda", non_blocking=True)
ema_unet.step(unet.parameters())
if args.offload_ema:
ema_unet.to(device="cpu", non_blocking=True)
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
...
@@ -274,6 +274,7 @@ class EMAModel:
use_ema_warmup: bool = False,
inv_gamma: Union[float, int] = 1.0,
power: Union[float, int] = 2 / 3,
foreach: bool = False,
model_cls: Optional[Any] = None,
model_config: Dict[str, Any] = None,
**kwargs,
@@ -288,6 +289,7 @@ class EMAModel:
inv_gamma (float):
Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
weights will be stored on CPU.
@@ -342,16 +344,17 @@ class EMAModel:
self.power = power
self.optimization_step = 0
self.cur_decay_value = None  # set in `step()`
self.foreach = foreach
self.model_cls = model_cls
self.model_config = model_config
@classmethod
def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
_, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
model = model_cls.from_pretrained(path)
ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)
ema_model.load_state_dict(ema_kwargs)
return ema_model
@@ -418,6 +421,28 @@ class EMAModel:
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
import deepspeed
if self.foreach:
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
with context_manager():
params_grad = [param for param in parameters if param.requires_grad]
s_params_grad = [
s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
]
if len(params_grad) < len(parameters):
torch._foreach_copy_(
[s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
[param for param in parameters if not param.requires_grad],
non_blocking=True,
)
torch._foreach_sub_(
s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay
)
else:
for s_param, param in zip(self.shadow_params, parameters):
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
@@ -438,10 +463,24 @@ class EMAModel:
`ExponentialMovingAverage` was initialized will be used.
"""
parameters = list(parameters)
if self.foreach:
torch._foreach_copy_(
[param.data for param in parameters],
[s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
)
else:
for s_param, param in zip(self.shadow_params, parameters):
param.data.copy_(s_param.to(param.device).data)
def pin_memory(self) -> None:
r"""
Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for
offloading EMA params to the host.
"""
self.shadow_params = [p.pin_memory() for p in self.shadow_params]
def to(self, device=None, dtype=None, non_blocking=False) -> None:
r"""Move internal buffers of the ExponentialMovingAverage to `device`. r"""Move internal buffers of the ExponentialMovingAverage to `device`.
Args: Args:
...@@ -449,7 +488,9 @@ class EMAModel: ...@@ -449,7 +488,9 @@ class EMAModel:
""" """
# .to() on the tensors handles None correctly # .to() on the tensors handles None correctly
self.shadow_params = [ self.shadow_params = [
p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) p.to(device=device, dtype=dtype, non_blocking=non_blocking)
if p.is_floating_point()
else p.to(device=device, non_blocking=non_blocking)
for p in self.shadow_params
]
@@ -493,6 +534,11 @@ class EMAModel:
"""
if self.temp_stored_params is None:
raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
if self.foreach:
torch._foreach_copy_(
[param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
)
else:
for c_param, param in zip(self.temp_stored_params, parameters):
param.data.copy_(c_param.data)
...
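As a usage note (a sketch, not part of the commit): EMA weights that the training script saves under a `unet_ema` subfolder can be reloaded with the new `foreach` flag and copied into a UNet via `EMAModel.copy_to`. The checkpoint directory below is a hypothetical placeholder.

```python
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

checkpoint_dir = "path/to/checkpoint"  # hypothetical; contains "unet" and "unet_ema" subfolders

unet = UNet2DConditionModel.from_pretrained(checkpoint_dir, subfolder="unet")
ema_unet = EMAModel.from_pretrained(
    f"{checkpoint_dir}/unet_ema", UNet2DConditionModel, foreach=True
)
ema_unet.copy_to(unet.parameters())  # overwrite the UNet weights with the EMA shadow params
```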
@@ -157,3 +157,138 @@ class EMAModelTests(unittest.TestCase):
output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample
assert torch.allclose(output, output_loaded, atol=1e-4)
class EMAModelTestsForeach(unittest.TestCase):
model_id = "hf-internal-testing/tiny-stable-diffusion-pipe"
batch_size = 1
prompt_length = 77
text_encoder_hidden_dim = 32
num_in_channels = 4
latent_height = latent_width = 64
generator = torch.manual_seed(0)
def get_models(self, decay=0.9999):
unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet")
unet = unet.to(torch_device)
ema_unet = EMAModel(
unet.parameters(), decay=decay, model_cls=UNet2DConditionModel, model_config=unet.config, foreach=True
)
return unet, ema_unet
def get_dummy_inputs(self):
noisy_latents = torch.randn(
self.batch_size, self.num_in_channels, self.latent_height, self.latent_width, generator=self.generator
).to(torch_device)
timesteps = torch.randint(0, 1000, size=(self.batch_size,), generator=self.generator).to(torch_device)
encoder_hidden_states = torch.randn(
self.batch_size, self.prompt_length, self.text_encoder_hidden_dim, generator=self.generator
).to(torch_device)
return noisy_latents, timesteps, encoder_hidden_states
def simulate_backprop(self, unet):
updated_state_dict = {}
for k, param in unet.state_dict().items():
updated_param = torch.randn_like(param) + (param * torch.randn_like(param))
updated_state_dict.update({k: updated_param})
unet.load_state_dict(updated_state_dict)
return unet
def test_optimization_steps_updated(self):
unet, ema_unet = self.get_models()
# Take the first (hypothetical) EMA step.
ema_unet.step(unet.parameters())
assert ema_unet.optimization_step == 1
# Take two more.
for _ in range(2):
ema_unet.step(unet.parameters())
assert ema_unet.optimization_step == 3
def test_shadow_params_not_updated(self):
unet, ema_unet = self.get_models()
# Since the `unet` is not being updated (i.e., backprop'd)
# there won't be any difference between the `params` of `unet`
# and `ema_unet` even if we call `ema_unet.step(unet.parameters())`.
ema_unet.step(unet.parameters())
orig_params = list(unet.parameters())
for s_param, param in zip(ema_unet.shadow_params, orig_params):
assert torch.allclose(s_param, param)
# The above holds true even if we call `ema.step()` multiple times since
# `unet` params are still not being updated.
for _ in range(4):
ema_unet.step(unet.parameters())
for s_param, param in zip(ema_unet.shadow_params, orig_params):
assert torch.allclose(s_param, param)
def test_shadow_params_updated(self):
unet, ema_unet = self.get_models()
# Here we simulate the parameter updates for `unet`. Since there might
# be some parameters which are initialized to zero we take extra care to
# initialize their values to something non-zero before the multiplication.
unet_pseudo_updated_step_one = self.simulate_backprop(unet)
# Take the EMA step.
ema_unet.step(unet_pseudo_updated_step_one.parameters())
# Now the EMA'd parameters won't be equal to the original model parameters.
orig_params = list(unet_pseudo_updated_step_one.parameters())
for s_param, param in zip(ema_unet.shadow_params, orig_params):
assert not torch.allclose(s_param, param)
# Ensure this is the case when we take multiple EMA steps.
for _ in range(4):
ema_unet.step(unet.parameters())
for s_param, param in zip(ema_unet.shadow_params, orig_params):
assert not torch.allclose(s_param, param)
def test_consecutive_shadow_params_updated(self):
# If we call EMA step after a backpropagation consecutively for two times,
# the shadow params from those two steps should be different.
unet, ema_unet = self.get_models()
# First backprop + EMA
unet_step_one = self.simulate_backprop(unet)
ema_unet.step(unet_step_one.parameters())
step_one_shadow_params = ema_unet.shadow_params
# Second backprop + EMA
unet_step_two = self.simulate_backprop(unet_step_one)
ema_unet.step(unet_step_two.parameters())
step_two_shadow_params = ema_unet.shadow_params
for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params):
assert ~torch.allclose(step_one, step_two)
def test_zero_decay(self):
# If there's no decay even if there are backprops, EMA steps
# won't take any effect i.e., the shadow params would remain the
# same.
unet, ema_unet = self.get_models(decay=0.0)
unet_step_one = self.simulate_backprop(unet)
ema_unet.step(unet_step_one.parameters())
step_one_shadow_params = ema_unet.shadow_params
unet_step_two = self.simulate_backprop(unet_step_one)
ema_unet.step(unet_step_two.parameters())
step_two_shadow_params = ema_unet.shadow_params
for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params):
assert torch.allclose(step_one, step_two)
@skip_mps
def test_serialization(self):
unet, ema_unet = self.get_models()
noisy_latents, timesteps, encoder_hidden_states = self.get_dummy_inputs()
with tempfile.TemporaryDirectory() as tmpdir:
ema_unet.save_pretrained(tmpdir)
loaded_unet = UNet2DConditionModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel)
loaded_unet = loaded_unet.to(unet.device)
# Since no EMA step has been performed the outputs should match.
output = unet(noisy_latents, timesteps, encoder_hidden_states).sample
output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample
assert torch.allclose(output, output_loaded, atol=1e-4)