Unverified commit 2f6f426f, authored by Sayak Paul and committed by GitHub

[Hunyuan] allow Hunyuan DiT to run under 6GB for GPU VRAM (#8399)

* allow hunyuan dit to run under 6GB for GPU VRAM

* add section in the docs/
parent a0542c19
@@ -29,6 +29,10 @@ HunyuanDiT has the following components:
 * It combines two text encoders, a bilingual CLIP and a multilingual T5 encoder
 
+## Memory optimization
+
+By loading the T5 text encoder in 8 bits, you can run the pipeline in just under 6 GBs of GPU VRAM. Refer to [this script](https://gist.github.com/sayakpaul/3154605f6af05b98a41081aaba5ca43e) for details.
+
 ## HunyuanDiTPipeline
 
 [[autodoc]] HunyuanDiTPipeline
...
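The new "Memory optimization" section defers to a gist for the actual code. Below is a minimal sketch of the two-stage flow that gist describes: encode prompts with an 8-bit T5 first, free the text encoders, then denoise from the precomputed embeddings. The checkpoint id and exact keyword names are assumptions based on the Diffusers release of HunyuanDiT, not part of this commit; `bitsandbytes` and `accelerate` must be installed.

```python
# Sketch (not part of the commit): run HunyuanDiT in under ~6 GB of VRAM by
# loading the mT5 text encoder in 8-bit and splitting encoding from denoising.
import gc

import torch
from transformers import T5EncoderModel

from diffusers import HunyuanDiTPipeline

ckpt = "Tencent-Hunyuan/HunyuanDiT-Diffusers"  # assumed checkpoint id

# Stage 1: load only the text encoders; the T5 goes through bitsandbytes 8-bit.
text_encoder_2 = T5EncoderModel.from_pretrained(
    ckpt, subfolder="text_encoder_2", load_in_8bit=True, device_map="auto"
)
pipe = HunyuanDiTPipeline.from_pretrained(
    ckpt, text_encoder_2=text_encoder_2, transformer=None, vae=None
)

prompt = "a photo of an astronaut riding a horse on mars"
with torch.no_grad():
    # Relies on this commit's change: device/dtype are now resolved internally.
    emb, neg_emb, mask, neg_mask = pipe.encode_prompt(prompt, text_encoder_index=0)
    emb_2, neg_emb_2, mask_2, neg_mask_2 = pipe.encode_prompt(prompt, text_encoder_index=1)

# Free the encoders before loading the heavy components.
del text_encoder_2, pipe
gc.collect()
torch.cuda.empty_cache()

# Stage 2: load only the transformer and VAE, denoise from the embeddings.
pipe = HunyuanDiTPipeline.from_pretrained(
    ckpt, text_encoder=None, text_encoder_2=None, torch_dtype=torch.float16
).to("cuda")
image = pipe(
    prompt_embeds=emb,
    negative_prompt_embeds=neg_emb,
    prompt_attention_mask=mask,
    negative_prompt_attention_mask=neg_mask,
    prompt_embeds_2=emb_2,
    negative_prompt_embeds_2=neg_emb_2,
    prompt_attention_mask_2=mask_2,
    negative_prompt_attention_mask_2=neg_mask_2,
).images[0]
image.save("out.png")
```

Stage 1 only works because the pipeline code below now tolerates `transformer=None` and `vae=None`, which is exactly what the next hunk adds.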
@@ -228,16 +228,22 @@ class HunyuanDiTPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = (
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+        )
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
-        self.default_sample_size = self.transformer.config.sample_size
+        self.default_sample_size = (
+            self.transformer.config.sample_size
+            if hasattr(self, "transformer") and self.transformer is not None
+            else 128
+        )
 
     def encode_prompt(
         self,
         prompt: str,
-        device: torch.device,
-        dtype: torch.dtype,
+        device: torch.device = None,
+        dtype: torch.dtype = None,
         num_images_per_prompt: int = 1,
         do_classifier_free_guidance: bool = True,
         negative_prompt: Optional[str] = None,
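The fallback constants are not arbitrary: with the stock components, `vae.config.block_out_channels` has four entries (so the scale factor is 2 ** 3 = 8) and the transformer's `sample_size` is 128, matching the model's 1024-px default (128 * 8). A quick check of that reading, assuming the same checkpoint id as in the sketch above:

```python
# Sanity check (sketch): the hardcoded fallbacks 8 and 128 should equal the
# values derived from the real components. Checkpoint id is an assumption.
from diffusers import AutoencoderKL, HunyuanDiT2DModel

ckpt = "Tencent-Hunyuan/HunyuanDiT-Diffusers"
vae = AutoencoderKL.from_pretrained(ckpt, subfolder="vae")
transformer = HunyuanDiT2DModel.from_pretrained(ckpt, subfolder="transformer")

assert 2 ** (len(vae.config.block_out_channels) - 1) == 8
assert transformer.config.sample_size == 128  # 128 latent px * 8 = 1024 px output
```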
@@ -281,6 +287,17 @@ class HunyuanDiTPipeline(DiffusionPipeline):
             text_encoder_index (`int`, *optional*):
                 Index of the text encoder to use. `0` for clip and `1` for T5.
         """
+        if dtype is None:
+            if self.text_encoder_2 is not None:
+                dtype = self.text_encoder_2.dtype
+            elif self.transformer is not None:
+                dtype = self.transformer.dtype
+            else:
+                dtype = None
+
+        if device is None:
+            device = self._execution_device
+
         tokenizers = [self.tokenizer, self.tokenizer_2]
         text_encoders = [self.text_encoder, self.text_encoder_2]
...
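Taken together, the two hunks make `encode_prompt` callable without an explicit `device` or `dtype`: the device falls back to `self._execution_device`, and the dtype to the T5 encoder's (or, failing that, the transformer's). A minimal usage sketch, assuming `pipe` is an already-loaded `HunyuanDiTPipeline`:

```python
# Standalone embedding pass with the new defaults (sketch; `pipe` is assumed
# to be a loaded HunyuanDiTPipeline). No device/dtype arguments needed.
prompt_embeds, negative_prompt_embeds, mask, negative_mask = pipe.encode_prompt(
    "a watercolor painting of a fox",
    text_encoder_index=0,  # 0 = bilingual CLIP, 1 = multilingual T5
)
```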