Unverified Commit 2f6f426f authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Hunyuan] allow Hunyuan DiT to run under 6GB for GPU VRAM (#8399)

* allow hunyuan dit to run under 6GB for GPU VRAM

* add section in the docs/
parent a0542c19
......@@ -29,6 +29,10 @@ HunyuanDiT has the following components:
* It combines two text encoders, a bilingual CLIP and a multilingual T5 encoder
## Memory optimization
By loading the T5 text encoder in 8 bits, you can run the pipeline in just under 6 GBs of GPU VRAM. Refer to [this script](https://gist.github.com/sayakpaul/3154605f6af05b98a41081aaba5ca43e) for details.
## HunyuanDiTPipeline
[[autodoc]] HunyuanDiTPipeline
......
......@@ -228,16 +228,22 @@ class HunyuanDiTPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.default_sample_size = self.transformer.config.sample_size
self.default_sample_size = (
self.transformer.config.sample_size
if hasattr(self, "transformer") and self.transformer is not None
else 128
)
def encode_prompt(
self,
prompt: str,
device: torch.device,
dtype: torch.dtype,
device: torch.device = None,
dtype: torch.dtype = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
......@@ -281,6 +287,17 @@ class HunyuanDiTPipeline(DiffusionPipeline):
text_encoder_index (`int`, *optional*):
Index of the text encoder to use. `0` for clip and `1` for T5.
"""
if dtype is None:
if self.text_encoder_2 is not None:
dtype = self.text_encoder_2.dtype
elif self.transformer is not None:
dtype = self.transformer.dtype
else:
dtype = None
if device is None:
device = self._execution_device
tokenizers = [self.tokenizer, self.tokenizer_2]
text_encoders = [self.text_encoder, self.text_encoder_2]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment