Unverified Commit 071807c8 authored by Sayak Paul, committed by GitHub

[training] feat: enable quantization for hidream lora training. (#11494)



* feat: enable quantization for hidream lora training.

* better handle compute dtype.

* finalize.

* fix dtype.

---------
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent ee1516e5
@@ -117,3 +117,30 @@ We provide several options for memory optimization:
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.

Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to learn more about the model.

## Using quantization
You can quantize the base model with [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/index) to reduce memory usage. To do so, pass a JSON file path to `--bnb_quantization_config_path`. This file should hold the configuration to initialize `BitsAndBytesConfig`. Below is an example JSON file:
```json
{
"load_in_4bit": true,
"bnb_4bit_quant_type": "nf4"
}
```
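
For reference, a minimal sketch of what the training script does with this file, mirroring the logic added in the script below: the parsed keyword arguments build a `BitsAndBytesConfig`, which is passed to `from_pretrained()` when loading the transformer. The file path, repo id, and compute dtype here are illustrative only:

```python
import json

import torch
from diffusers import BitsAndBytesConfig, HiDreamImageTransformer2DModel

# Read the JSON file passed via `--bnb_quantization_config_path` (the path here is illustrative).
with open("bnb_nf4_config.json", "r") as f:
    config_kwargs = json.load(f)

# For 4-bit quantization, the script aligns the bnb compute dtype with the mixed-precision dtype.
if config_kwargs.get("load_in_4bit"):
    config_kwargs["bnb_4bit_compute_dtype"] = torch.bfloat16

quantization_config = BitsAndBytesConfig(**config_kwargs)

# The script uses `--pretrained_model_name_or_path`; this repo id is only an example.
transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```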
Below, we provide memory numbers observed with and without NF4 quantization during training:
```
(with quantization)
Memory (before device placement): 9.085089683532715 GB.
Memory (after device placement): 34.59585428237915 GB.
Memory (after backward): 36.90267467498779 GB.
(without quantization)
Memory (before device placement): 0.0 GB.
Memory (after device placement): 57.6400408744812 GB.
Memory (after backward): 59.932212829589844 GB.
```
We see non-zero memory before device placement in the quantized case because, by default, `bitsandbytes`-quantized models are placed on the GPU as they are loaded.
\ No newline at end of file
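
The memory readings above can be gathered with standard `torch.cuda` queries; a minimal sketch of such logging (the exact helper used by the training script may differ):

```python
import torch


def print_memory(stage: str) -> None:
    # Report the CUDA memory currently allocated by tensors, in GB.
    allocated_gb = torch.cuda.memory_allocated() / 1024**3
    print(f"Memory ({stage}): {allocated_gb} GB.")


# Example usage around training setup:
# print_memory("before device placement")
# transformer.to("cuda")
# print_memory("after device placement")
# loss.backward()
# print_memory("after backward")
```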
@@ -16,6 +16,7 @@
import argparse
import copy
import itertools
import json
import logging
import math
import os
@@ -27,14 +28,13 @@ from pathlib import Path
import numpy as np
import torch
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from peft import LoraConfig, set_peft_model_state_dict
from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
@@ -47,6 +47,7 @@ from transformers import AutoTokenizer, CLIPTokenizer, LlamaForCausalLM, Pretrai
import diffusers
from diffusers import (
    AutoencoderKL,
    BitsAndBytesConfig,
    FlowMatchEulerDiscreteScheduler,
    HiDreamImagePipeline,
    HiDreamImageTransformer2DModel,
@@ -282,6 +283,12 @@ def parse_args(input_args=None):
        default="meta-llama/Meta-Llama-3.1-8B-Instruct",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--bnb_quantization_config_path",
        type=str,
        default=None,
        help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.",
    )
    parser.add_argument(
        "--revision",
        type=str,
@@ -1056,6 +1063,14 @@ def main(args):
        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
    )

    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Load scheduler and models
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision, shift=3.0
@@ -1064,20 +1079,31 @@
    text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(
        text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
    )

    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
        variant=args.variant,
    )
    quantization_config = None
    if args.bnb_quantization_config_path is not None:
        with open(args.bnb_quantization_config_path, "r") as f:
            config_kwargs = json.load(f)
        # Keep the bnb 4-bit compute dtype aligned with the mixed-precision weight dtype.
        if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]:
            config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype
        quantization_config = BitsAndBytesConfig(**config_kwargs)
    transformer = HiDreamImageTransformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        revision=args.revision,
        variant=args.variant,
        quantization_config=quantization_config,
        torch_dtype=weight_dtype,
        force_inference_output=True,
    )
    if args.bnb_quantization_config_path is not None:
        transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)

    # We only train the additional adapter LoRA layers
    transformer.requires_grad_(False)
@@ -1087,14 +1113,6 @@
    text_encoder_three.requires_grad_(False)
    text_encoder_four.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
@@ -1109,7 +1127,12 @@
    text_encoder_three.to(**to_kwargs)
    text_encoder_four.to(**to_kwargs)

    # we never offload the transformer to CPU, so we can just use the accelerator device
    transformer.to(accelerator.device, dtype=weight_dtype)
    transformer_to_kwargs = (
        {"device": accelerator.device}
        if args.bnb_quantization_config_path is not None
        else {"device": accelerator.device, "dtype": weight_dtype}
    )
    transformer.to(**transformer_to_kwargs)

    # Initialize a text encoding pipeline and keep it to CPU for now.
    text_encoding_pipeline = HiDreamImagePipeline.from_pretrained(
@@ -1695,6 +1718,7 @@
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        transformer = unwrap_model(transformer)
        if args.bnb_quantization_config_path is None:
            if args.upcast_before_saving:
                transformer.to(torch.float32)
            else:
...
@@ -179,7 +179,7 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
    This is a wrapper class about all possible attributes and features that you can play with a model that has been
    loaded using `bitsandbytes`.

    This replaces `load_in_8bit` or `load_in_4bit`therefore both options are mutually exclusive.
    This replaces `load_in_8bit` or `load_in_4bit` therefore both options are mutually exclusive.

    Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`,
    then more arguments will be added to this class.
...
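
As the docstring notes, `load_in_8bit` and `load_in_4bit` are mutually exclusive, so an 8-bit run would use its own config rather than the NF4 example above. A small illustrative sketch (not part of this commit):

```python
from diffusers import BitsAndBytesConfig

# LLM.int8() quantization; do not combine load_in_8bit with load_in_4bit.
int8_config = BitsAndBytesConfig(load_in_8bit=True)
```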