renzhc / diffusers_dcu · Commits

Commit 69f919d8 (unverified)
Authored Feb 14, 2025 by YiYi Xu; committed by GitHub on Feb 14, 2025

    follow-up refactor on lumina2 (#10776)

    * up

Parent: a6b843a7
Showing 3 changed files with 86 additions and 123 deletions (+86 −123):

  src/diffusers/models/transformers/transformer_lumina2.py       +83 −107
  src/diffusers/pipelines/lumina2/pipeline_lumina2.py              +2  −15
  tests/models/transformers/test_models_transformer_lumina2.py    +1   −1
src/diffusers/models/transformers/transformer_lumina2.py (view file @ 69f919d8)

@@ -242,97 +242,85 @@ class Lumina2RotaryPosEmbed(nn.Module):
     def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
         freqs_cis = []
-        # Use float32 for MPS compatibility
-        dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
         for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
-            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=dtype)
+            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
             freqs_cis.append(emb)
         return freqs_cis
 
     def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
+        device = ids.device
+        if ids.device.type == "mps":
+            ids = ids.to("cpu")
         result = []
         for i in range(len(self.axes_dim)):
             freqs = self.freqs_cis[i].to(ids.device)
             index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
             result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
-        return torch.cat(result, dim=-1)
+        return torch.cat(result, dim=-1).to(device)
 
     def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
-        batch_size = len(hidden_states)
-        p_h = p_w = self.patch_size
-        device = hidden_states[0].device
+        batch_size, channels, height, width = hidden_states.shape
+        p = self.patch_size
+        post_patch_height, post_patch_width = height // p, width // p
+        image_seq_len = post_patch_height * post_patch_width
+        device = hidden_states.device
 
+        encoder_seq_len = attention_mask.shape[1]
         l_effective_cap_len = attention_mask.sum(dim=1).tolist()
-        img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
-        l_effective_img_len = [(H // p_h) * (W // p_w) for (H, W) in img_sizes]
-
-        max_seq_len = max((cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)))
-        max_img_len = max(l_effective_img_len)
+        # TODO: this should probably be refactored because all subtensors of hidden_states will be of same shape
+        seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
+        max_seq_len = max(seq_lengths)
 
         # Create position IDs
         position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
 
-        for i in range(batch_size):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-
-            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
-            position_ids[i, cap_len : cap_len + img_len, 0] = cap_len
-
-            H, W = img_sizes[i]
-            H_tokens, W_tokens = H // p_h, W // p_w
-            assert H_tokens * W_tokens == img_len
-
-            row_ids = (
-                torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
-            )
-            col_ids = (
-                torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
-            )
-            position_ids[i, cap_len : cap_len + img_len, 1] = row_ids
-            position_ids[i, cap_len : cap_len + img_len, 2] = col_ids
+        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+            # add caption position ids
+            position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
+            position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len
+            # add image position ids
+            row_ids = (
+                torch.arange(post_patch_height, dtype=torch.int32, device=device)
+                .view(-1, 1)
+                .repeat(1, post_patch_width)
+                .flatten()
+            )
+            col_ids = (
+                torch.arange(post_patch_width, dtype=torch.int32, device=device)
+                .view(1, -1)
+                .repeat(post_patch_height, 1)
+                .flatten()
+            )
+            position_ids[i, cap_seq_len:seq_len, 1] = row_ids
+            position_ids[i, cap_seq_len:seq_len, 2] = col_ids
 
         # Get combined rotary embeddings
         freqs_cis = self._get_freqs_cis(position_ids)
 
-        cap_freqs_cis_shape = list(freqs_cis.shape)
-        cap_freqs_cis_shape[1] = attention_mask.shape[1]
-        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        img_freqs_cis_shape = list(freqs_cis.shape)
-        img_freqs_cis_shape[1] = max_img_len
-        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        for i in range(batch_size):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
-            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len]
-
-        flat_hidden_states = []
-        for i in range(batch_size):
-            img = hidden_states[i]
-            C, H, W = img.size()
-            img = img.view(C, H // p_h, p_h, W // p_w, p_w).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
-            flat_hidden_states.append(img)
-        hidden_states = flat_hidden_states
-        padded_img_embed = torch.zeros(
-            batch_size, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype
-        )
-        padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
-        for i in range(batch_size):
-            padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i]
-            padded_img_mask[i, : l_effective_img_len[i]] = True
-
-        return (
-            padded_img_embed,
-            padded_img_mask,
-            img_sizes,
-            l_effective_cap_len,
-            l_effective_img_len,
-            freqs_cis,
-            cap_freqs_cis,
-            img_freqs_cis,
-            max_seq_len,
-        )
+        # create separate rotary embeddings for captions and images
+        cap_freqs_cis = torch.zeros(
+            batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+        )
+        img_freqs_cis = torch.zeros(
+            batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+        )
+
+        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
+            img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]
+
+        # image patch embeddings
+        hidden_states = (
+            hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
+            .permute(0, 2, 4, 3, 5, 1)
+            .flatten(3)
+            .flatten(1, 2)
+        )
+
+        return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
 
 
 class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
@@ -472,75 +460,63 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
         hidden_states: torch.Tensor,
         timestep: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        use_mask_in_transformer: bool = True,
+        encoder_attention_mask: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
-        batch_size = hidden_states.size(0)
-
         # 1. Condition, positional & patch embedding
+        batch_size, _, height, width = hidden_states.shape
+
         temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
 
         (
             hidden_states,
-            hidden_mask,
-            hidden_sizes,
-            encoder_hidden_len,
-            hidden_len,
-            joint_rotary_emb,
-            encoder_rotary_emb,
-            hidden_rotary_emb,
-            max_seq_len,
-        ) = self.rope_embedder(hidden_states, attention_mask)
+            context_rotary_emb,
+            noise_rotary_emb,
+            rotary_emb,
+            encoder_seq_lengths,
+            seq_lengths,
+        ) = self.rope_embedder(hidden_states, encoder_attention_mask)
 
         hidden_states = self.x_embedder(hidden_states)
 
         # 2. Context & noise refinement
         for layer in self.context_refiner:
-            # NOTE: mask not used for performance
-            encoder_hidden_states = layer(encoder_hidden_states, attention_mask if use_mask_in_transformer else None, encoder_rotary_emb)
+            encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)
 
         for layer in self.noise_refiner:
-            # NOTE: mask not used for performance
-            hidden_states = layer(hidden_states, hidden_mask if use_mask_in_transformer else None, hidden_rotary_emb, temb)
+            hidden_states = layer(hidden_states, None, noise_rotary_emb, temb)
 
-        # 3. Attention mask preparation
-        mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
-        padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
-        for i in range(batch_size):
-            cap_len = encoder_hidden_len[i]
-            img_len = hidden_len[i]
-            mask[i, : cap_len + img_len] = True
-            padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len]
-            padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len]
-        hidden_states = padded_hidden_states
+        # 3. Joint Transformer blocks
+        max_seq_len = max(seq_lengths)
+        use_mask = len(set(seq_lengths)) > 1
+        attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
+        joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
+        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+            attention_mask[i, :seq_len] = True
+            joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
+            joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]
+        hidden_states = joint_hidden_states
 
-        # 4. Transformer blocks
         for layer in self.layers:
-            # NOTE: mask not used for performance
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = self._gradient_checkpointing_func(
-                    layer, hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb
+                    layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb
                 )
             else:
-                hidden_states = layer(hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb)
+                hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb)
 
-        # 5. Output norm & projection & unpatchify
+        # 4. Output norm & projection
         hidden_states = self.norm_out(hidden_states, temb)
 
-        height_tokens = width_tokens = self.config.patch_size
-
+        # 5. Unpatchify
+        p = self.config.patch_size
         output = []
-        for i in range(len(hidden_sizes)):
-            height, width = hidden_sizes[i]
-            begin = encoder_hidden_len[i]
-            end = begin + (height // height_tokens) * (width // width_tokens)
-            output.append(
-                hidden_states[i][begin:end]
-                .view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels)
-                .permute(4, 0, 2, 1, 3)
-                .flatten(3, 4)
-                .flatten(1, 2)
-            )
+        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+            output.append(
+                hidden_states[i][encoder_seq_len:seq_len]
+                .view(height // p, width // p, p, p, self.out_channels)
+                .permute(4, 0, 2, 1, 3)
+                .flatten(3, 4)
+                .flatten(1, 2)
+            )
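
Note (not part of the diff): the new "# 3. Joint Transformer blocks" section packs each caption plus the fixed-length image sequence into one padded tensor and only applies the mask when lengths differ across the batch. A toy sketch of that packing, with illustrative sizes:

    import torch

    batch_size, hidden_size, image_seq_len = 2, 8, 4
    encoder_seq_lengths = [3, 5]  # per-sample effective caption lengths
    seq_lengths = [n + image_seq_len for n in encoder_seq_lengths]  # [7, 9]
    max_seq_len = max(seq_lengths)
    use_mask = len(set(seq_lengths)) > 1  # mask only needed when rows differ

    encoder_hidden_states = torch.randn(batch_size, max(encoder_seq_lengths), hidden_size)
    hidden_states = torch.randn(batch_size, image_seq_len, hidden_size)  # post-patchify tokens

    attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
    joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, hidden_size)
    for i, (enc_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
        attention_mask[i, :seq_len] = True
        joint_hidden_states[i, :enc_len] = encoder_hidden_states[i, :enc_len]
        joint_hidden_states[i, enc_len:seq_len] = hidden_states[i]

    print(use_mask, attention_mask.sum(dim=1).tolist())  # True [7, 9]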
src/diffusers/pipelines/lumina2/pipeline_lumina2.py (view file @ 69f919d8)

@@ -24,8 +24,6 @@ from ...models import AutoencoderKL
 from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
-    is_bs4_available,
-    is_ftfy_available,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
@@ -44,12 +42,6 @@ else:
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
-if is_bs4_available():
-    pass
-
-if is_ftfy_available():
-    pass
-
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -527,7 +519,6 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline):
         system_prompt: Optional[str] = None,
         cfg_trunc_ratio: float = 1.0,
         cfg_normalization: bool = True,
-        use_mask_in_transformer: bool = True,
         max_sequence_length: int = 256,
     ) -> Union[ImagePipelineOutput, Tuple]:
         """
@@ -599,8 +590,6 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline):
                 The ratio of the timestep interval to apply normalization-based guidance scale.
             cfg_normalization (`bool`, *optional*, defaults to `True`):
                 Whether to apply normalization-based guidance scale.
-            use_mask_in_transformer (`bool`, *optional*, defaults to `True`):
-                Whether to use attention mask in `Lumina2Transformer2DModel`. Set `False` for performance gain.
             max_sequence_length (`int`, defaults to `256`):
                 Maximum sequence length to use with the `prompt`.
@@ -706,8 +695,7 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline):
                     hidden_states=latents,
                     timestep=current_timestep,
                     encoder_hidden_states=prompt_embeds,
-                    attention_mask=prompt_attention_mask,
-                    use_mask_in_transformer=use_mask_in_transformer,
+                    encoder_attention_mask=prompt_attention_mask,
                     return_dict=False,
                 )[0]
@@ -717,8 +705,7 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline):
                         hidden_states=latents,
                         timestep=current_timestep,
                         encoder_hidden_states=negative_prompt_embeds,
-                        attention_mask=negative_prompt_attention_mask,
-                        use_mask_in_transformer=use_mask_in_transformer,
+                        encoder_attention_mask=negative_prompt_attention_mask,
                         return_dict=False,
                     )[0]
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
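
Note (not part of the diff): after this change the pipeline call site drops use_mask_in_transformer entirely and passes the prompt mask as encoder_attention_mask; whether masking is applied is now decided inside the transformer. A hedged usage sketch (the checkpoint id is illustrative):

    import torch
    from diffusers import Lumina2Text2ImgPipeline

    pipe = Lumina2Text2ImgPipeline.from_pretrained(
        "Alpha-VLLM/Lumina-Image-2.0",  # illustrative checkpoint id
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    image = pipe(
        prompt="a photo of a cat wearing a red scarf",
        cfg_trunc_ratio=1.0,  # documented defaults shown explicitly
        cfg_normalization=True,
        max_sequence_length=256,
    ).images[0]
    image.save("lumina2.png")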
tests/models/transformers/test_models_transformer_lumina2.py (view file @ 69f919d8)

@@ -51,7 +51,7 @@ class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, unittest.TestCase):
             "hidden_states": hidden_states,
             "encoder_hidden_states": encoder_hidden_states,
             "timestep": timestep,
-            "attention_mask": attention_mask,
+            "encoder_attention_mask": attention_mask,
         }
 
     @property