Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
4267d8f4
Unverified
Commit
4267d8f4
authored
May 15, 2025
by
Dhruv Nair
Committed by
GitHub
May 15, 2025
Browse files
[Single File] GGUF/Single File Support for HiDream (#11550)
* update * update * update * update * update * update * update
parent
f4fa3bee
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
67 additions
and
5 deletions
+67
-5
docs/source/en/api/models/hidream_image_transformer.md
docs/source/en/api/models/hidream_image_transformer.md
+16
-0
src/diffusers/loaders/single_file_model.py
src/diffusers/loaders/single_file_model.py
+5
-0
src/diffusers/loaders/single_file_utils.py
src/diffusers/loaders/single_file_utils.py
+13
-0
src/diffusers/models/transformers/transformer_hidream_image.py
...iffusers/models/transformers/transformer_hidream_image.py
+2
-2
src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
...ffusers/pipelines/hidream_image/pipeline_hidream_image.py
+3
-3
tests/quantization/gguf/test_gguf.py
tests/quantization/gguf/test_gguf.py
+28
-0
No files found.
docs/source/en/api/models/hidream_image_transformer.md
View file @
4267d8f4
...
@@ -21,6 +21,22 @@ from diffusers import HiDreamImageTransformer2DModel
...
@@ -21,6 +21,22 @@ from diffusers import HiDreamImageTransformer2DModel
transformer
=
HiDreamImageTransformer2DModel
.
from_pretrained
(
"HiDream-ai/HiDream-I1-Full"
,
subfolder
=
"transformer"
,
torch_dtype
=
torch
.
bfloat16
)
transformer
=
HiDreamImageTransformer2DModel
.
from_pretrained
(
"HiDream-ai/HiDream-I1-Full"
,
subfolder
=
"transformer"
,
torch_dtype
=
torch
.
bfloat16
)
```
```
## Loading GGUF quantized checkpoints for HiDream-I1

GGUF checkpoints for the `HiDreamImageTransformer2DModel` can be loaded using [`~FromOriginalModelMixin.from_single_file`].
```
python
import
torch
from
diffusers
import
GGUFQuantizationConfig
,
HiDreamImageTransformer2DModel
ckpt_path
=
"https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
transformer
=
HiDreamImageTransformer2DModel
.
from_single_file
(
ckpt_path
,
quantization_config
=
GGUFQuantizationConfig
(
compute_dtype
=
torch
.
bfloat16
),
torch_dtype
=
torch
.
bfloat16
)
```
## HiDreamImageTransformer2DModel
## HiDreamImageTransformer2DModel
[[autodoc]] HiDreamImageTransformer2DModel
[[autodoc]] HiDreamImageTransformer2DModel
...
...
src/diffusers/loaders/single_file_model.py
View file @
4267d8f4
...
@@ -31,6 +31,7 @@ from .single_file_utils import (
...
@@ -31,6 +31,7 @@ from .single_file_utils import (
convert_autoencoder_dc_checkpoint_to_diffusers
,
convert_autoencoder_dc_checkpoint_to_diffusers
,
convert_controlnet_checkpoint
,
convert_controlnet_checkpoint
,
convert_flux_transformer_checkpoint_to_diffusers
,
convert_flux_transformer_checkpoint_to_diffusers
,
convert_hidream_transformer_to_diffusers
,
convert_hunyuan_video_transformer_to_diffusers
,
convert_hunyuan_video_transformer_to_diffusers
,
convert_ldm_unet_checkpoint
,
convert_ldm_unet_checkpoint
,
convert_ldm_vae_checkpoint
,
convert_ldm_vae_checkpoint
,
...
@@ -133,6 +134,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
...
@@ -133,6 +134,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn"
:
convert_wan_vae_to_diffusers
,
"checkpoint_mapping_fn"
:
convert_wan_vae_to_diffusers
,
"default_subfolder"
:
"vae"
,
"default_subfolder"
:
"vae"
,
},
},
"HiDreamImageTransformer2DModel"
:
{
"checkpoint_mapping_fn"
:
convert_hidream_transformer_to_diffusers
,
"default_subfolder"
:
"transformer"
,
},
}
}
...
...
src/diffusers/loaders/single_file_utils.py
View file @
4267d8f4
...
@@ -126,6 +126,7 @@ CHECKPOINT_KEY_NAMES = {
...
@@ -126,6 +126,7 @@ CHECKPOINT_KEY_NAMES = {
],
],
"wan"
:
[
"model.diffusion_model.head.modulation"
,
"head.modulation"
],
"wan"
:
[
"model.diffusion_model.head.modulation"
,
"head.modulation"
],
"wan_vae"
:
"decoder.middle.0.residual.0.gamma"
,
"wan_vae"
:
"decoder.middle.0.residual.0.gamma"
,
"hidream"
:
"double_stream_blocks.0.block.adaLN_modulation.1.bias"
,
}
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS
=
{
DIFFUSERS_DEFAULT_PIPELINE_PATHS
=
{
...
@@ -190,6 +191,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
...
@@ -190,6 +191,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"wan-t2v-1.3B"
:
{
"pretrained_model_name_or_path"
:
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
},
"wan-t2v-1.3B"
:
{
"pretrained_model_name_or_path"
:
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
},
"wan-t2v-14B"
:
{
"pretrained_model_name_or_path"
:
"Wan-AI/Wan2.1-T2V-14B-Diffusers"
},
"wan-t2v-14B"
:
{
"pretrained_model_name_or_path"
:
"Wan-AI/Wan2.1-T2V-14B-Diffusers"
},
"wan-i2v-14B"
:
{
"pretrained_model_name_or_path"
:
"Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
},
"wan-i2v-14B"
:
{
"pretrained_model_name_or_path"
:
"Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
},
"hidream"
:
{
"pretrained_model_name_or_path"
:
"HiDream-ai/HiDream-I1-Dev"
},
}
}
# Use to configure model sample size when original config is provided
# Use to configure model sample size when original config is provided
...
@@ -701,6 +703,8 @@ def infer_diffusers_model_type(checkpoint):
...
@@ -701,6 +703,8 @@ def infer_diffusers_model_type(checkpoint):
elif
CHECKPOINT_KEY_NAMES
[
"wan_vae"
]
in
checkpoint
:
elif
CHECKPOINT_KEY_NAMES
[
"wan_vae"
]
in
checkpoint
:
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
model_type
=
"wan-t2v-14B"
model_type
=
"wan-t2v-14B"
elif
CHECKPOINT_KEY_NAMES
[
"hidream"
]
in
checkpoint
:
model_type
=
"hidream"
else
:
else
:
model_type
=
"v1"
model_type
=
"v1"
...
@@ -3293,3 +3297,12 @@ def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
...
@@ -3293,3 +3297,12 @@ def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
converted_state_dict
[
key
]
=
value
converted_state_dict
[
key
]
=
value
return
converted_state_dict
return
converted_state_dict
def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
    """Rename original HiDream checkpoint keys to the diffusers layout.

    Strips the ``"model.diffusion_model."`` prefix from every matching key,
    mutating *checkpoint* in place and returning it. Keys without the prefix
    are left untouched. Extra ``**kwargs`` are accepted (and ignored) to match
    the common ``checkpoint_mapping_fn`` signature.
    """
    prefix = "model.diffusion_model."
    # Snapshot the keys first: the dict is mutated while we walk it.
    for original_key in list(checkpoint.keys()):
        if prefix in original_key:
            checkpoint[original_key.replace(prefix, "")] = checkpoint.pop(original_key)
    return checkpoint
src/diffusers/models/transformers/transformer_hidream_image.py
View file @
4267d8f4
...
@@ -5,7 +5,7 @@ import torch.nn as nn
...
@@ -5,7 +5,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
...configuration_utils
import
ConfigMixin
,
register_to_config
from
...configuration_utils
import
ConfigMixin
,
register_to_config
from
...loaders
import
PeftAdapterMixin
from
...loaders
import
FromOriginalModelMixin
,
PeftAdapterMixin
from
...models.modeling_outputs
import
Transformer2DModelOutput
from
...models.modeling_outputs
import
Transformer2DModelOutput
from
...models.modeling_utils
import
ModelMixin
from
...models.modeling_utils
import
ModelMixin
from
...utils
import
USE_PEFT_BACKEND
,
deprecate
,
logging
,
scale_lora_layers
,
unscale_lora_layers
from
...utils
import
USE_PEFT_BACKEND
,
deprecate
,
logging
,
scale_lora_layers
,
unscale_lora_layers
...
@@ -602,7 +602,7 @@ class HiDreamBlock(nn.Module):
...
@@ -602,7 +602,7 @@ class HiDreamBlock(nn.Module):
)
)
class
HiDreamImageTransformer2DModel
(
ModelMixin
,
ConfigMixin
,
PeftAdapterMixin
):
class
HiDreamImageTransformer2DModel
(
ModelMixin
,
ConfigMixin
,
PeftAdapterMixin
,
FromOriginalModelMixin
):
_supports_gradient_checkpointing
=
True
_supports_gradient_checkpointing
=
True
_no_split_modules
=
[
"HiDreamImageTransformerBlock"
,
"HiDreamImageSingleTransformerBlock"
]
_no_split_modules
=
[
"HiDreamImageTransformerBlock"
,
"HiDreamImageSingleTransformerBlock"
]
...
...
src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
View file @
4267d8f4
...
@@ -36,11 +36,11 @@ EXAMPLE_DOC_STRING = """
...
@@ -36,11 +36,11 @@ EXAMPLE_DOC_STRING = """
Examples:
Examples:
```py
```py
>>> import torch
>>> import torch
>>> from transformers import
PreTrained
Tokenizer
Fast
, LlamaForCausalLM
>>> from transformers import
Auto
Tokenizer, LlamaForCausalLM
>>> from diffusers import
UniPCMultistepScheduler,
HiDreamImagePipeline
>>> from diffusers import HiDreamImagePipeline
>>> tokenizer_4 =
PreTrained
Tokenizer
Fast
.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> tokenizer_4 =
Auto
Tokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
... "meta-llama/Meta-Llama-3.1-8B-Instruct",
... "meta-llama/Meta-Llama-3.1-8B-Instruct",
... output_hidden_states=True,
... output_hidden_states=True,
...
...
tests/quantization/gguf/test_gguf.py
View file @
4267d8f4
...
@@ -12,6 +12,7 @@ from diffusers import (
...
@@ -12,6 +12,7 @@ from diffusers import (
FluxPipeline
,
FluxPipeline
,
FluxTransformer2DModel
,
FluxTransformer2DModel
,
GGUFQuantizationConfig
,
GGUFQuantizationConfig
,
HiDreamImageTransformer2DModel
,
SD3Transformer2DModel
,
SD3Transformer2DModel
,
StableDiffusion3Pipeline
,
StableDiffusion3Pipeline
,
)
)
...
@@ -549,3 +550,30 @@ class FluxControlLoRAGGUFTests(unittest.TestCase):
...
@@ -549,3 +550,30 @@ class FluxControlLoRAGGUFTests(unittest.TestCase):
max_diff
=
numpy_cosine_similarity_distance
(
expected_slice
,
out_slice
)
max_diff
=
numpy_cosine_similarity_distance
(
expected_slice
,
out_slice
)
self
.
assertTrue
(
max_diff
<
1e-3
)
self
.
assertTrue
(
max_diff
<
1e-3
)
class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
    """GGUF single-file loading tests for the HiDream image transformer."""

    ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
    torch_dtype = torch.bfloat16
    model_cls = HiDreamImageTransformer2DModel
    expected_memory_use_in_gb = 8

    def get_dummy_inputs(self):
        """Build deterministic dummy inputs for a forward pass of the model."""

        def _seeded_randn(shape):
            # Each tensor uses its own CPU generator seeded with 0, matching
            # the deterministic fixtures used by the other GGUF test classes.
            gen = torch.Generator("cpu").manual_seed(0)
            return torch.randn(shape, generator=gen).to(torch_device, self.torch_dtype)

        return {
            "hidden_states": _seeded_randn((1, 16, 128, 128)),
            "encoder_hidden_states_t5": _seeded_randn((1, 128, 4096)),
            "encoder_hidden_states_llama3": _seeded_randn((32, 1, 128, 4096)),
            "pooled_embeds": _seeded_randn((1, 2048)),
            "timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
        }
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment