Unverified Commit 69f919d8 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

follow-up refactor on lumina2 (#10776)

* up
parent a6b843a7
...@@ -242,97 +242,85 @@ class Lumina2RotaryPosEmbed(nn.Module): ...@@ -242,97 +242,85 @@ class Lumina2RotaryPosEmbed(nn.Module):
def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]: def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
freqs_cis = [] freqs_cis = []
# Use float32 for MPS compatibility freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=dtype) emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
freqs_cis.append(emb) freqs_cis.append(emb)
return freqs_cis return freqs_cis
def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor: def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
device = ids.device
if ids.device.type == "mps":
ids = ids.to("cpu")
result = [] result = []
for i in range(len(self.axes_dim)): for i in range(len(self.axes_dim)):
freqs = self.freqs_cis[i].to(ids.device) freqs = self.freqs_cis[i].to(ids.device)
index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
return torch.cat(result, dim=-1) return torch.cat(result, dim=-1).to(device)
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
batch_size = len(hidden_states) batch_size, channels, height, width = hidden_states.shape
p_h = p_w = self.patch_size p = self.patch_size
device = hidden_states[0].device post_patch_height, post_patch_width = height // p, width // p
image_seq_len = post_patch_height * post_patch_width
device = hidden_states.device
encoder_seq_len = attention_mask.shape[1]
l_effective_cap_len = attention_mask.sum(dim=1).tolist() l_effective_cap_len = attention_mask.sum(dim=1).tolist()
# TODO: this should probably be refactored because all subtensors of hidden_states will be of same shape seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] max_seq_len = max(seq_lengths)
l_effective_img_len = [(H // p_h) * (W // p_w) for (H, W) in img_sizes]
max_seq_len = max((cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)))
max_img_len = max(l_effective_img_len)
# Create position IDs
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device) position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
for i in range(batch_size): for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
cap_len = l_effective_cap_len[i] # add caption position ids
img_len = l_effective_img_len[i] position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
H, W = img_sizes[i] position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len
H_tokens, W_tokens = H // p_h, W // p_w
assert H_tokens * W_tokens == img_len
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) # add image position ids
position_ids[i, cap_len : cap_len + img_len, 0] = cap_len
row_ids = ( row_ids = (
torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() torch.arange(post_patch_height, dtype=torch.int32, device=device)
.view(-1, 1)
.repeat(1, post_patch_width)
.flatten()
) )
col_ids = ( col_ids = (
torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() torch.arange(post_patch_width, dtype=torch.int32, device=device)
.view(1, -1)
.repeat(post_patch_height, 1)
.flatten()
) )
position_ids[i, cap_len : cap_len + img_len, 1] = row_ids position_ids[i, cap_seq_len:seq_len, 1] = row_ids
position_ids[i, cap_len : cap_len + img_len, 2] = col_ids position_ids[i, cap_seq_len:seq_len, 2] = col_ids
# Get combined rotary embeddings
freqs_cis = self._get_freqs_cis(position_ids) freqs_cis = self._get_freqs_cis(position_ids)
cap_freqs_cis_shape = list(freqs_cis.shape) # create separate rotary embeddings for captions and images
cap_freqs_cis_shape[1] = attention_mask.shape[1] cap_freqs_cis = torch.zeros(
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
for i in range(batch_size):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len]
flat_hidden_states = []
for i in range(batch_size):
img = hidden_states[i]
C, H, W = img.size()
img = img.view(C, H // p_h, p_h, W // p_w, p_w).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
flat_hidden_states.append(img)
hidden_states = flat_hidden_states
padded_img_embed = torch.zeros(
batch_size, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype
) )
padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device) img_freqs_cis = torch.zeros(
for i in range(batch_size): batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i] )
padded_img_mask[i, : l_effective_img_len[i]] = True
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
return ( cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
padded_img_embed, img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]
padded_img_mask,
img_sizes, # image patch embeddings
l_effective_cap_len, hidden_states = (
l_effective_img_len, hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
freqs_cis, .permute(0, 2, 4, 3, 5, 1)
cap_freqs_cis, .flatten(3)
img_freqs_cis, .flatten(1, 2)
max_seq_len,
) )
return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r""" r"""
...@@ -472,75 +460,63 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO ...@@ -472,75 +460,63 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
timestep: torch.Tensor, timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
attention_mask: torch.Tensor, encoder_attention_mask: torch.Tensor,
use_mask_in_transformer: bool = True,
return_dict: bool = True, return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]: ) -> Union[torch.Tensor, Transformer2DModelOutput]:
batch_size = hidden_states.size(0)
# 1. Condition, positional & patch embedding # 1. Condition, positional & patch embedding
batch_size, _, height, width = hidden_states.shape
temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states) temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
( (
hidden_states, hidden_states,
hidden_mask, context_rotary_emb,
hidden_sizes, noise_rotary_emb,
encoder_hidden_len, rotary_emb,
hidden_len, encoder_seq_lengths,
joint_rotary_emb, seq_lengths,
encoder_rotary_emb, ) = self.rope_embedder(hidden_states, encoder_attention_mask)
hidden_rotary_emb,
max_seq_len,
) = self.rope_embedder(hidden_states, attention_mask)
hidden_states = self.x_embedder(hidden_states) hidden_states = self.x_embedder(hidden_states)
# 2. Context & noise refinement # 2. Context & noise refinement
for layer in self.context_refiner: for layer in self.context_refiner:
# NOTE: mask not used for performance encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)
encoder_hidden_states = layer(
encoder_hidden_states, attention_mask if use_mask_in_transformer else None, encoder_rotary_emb
)
for layer in self.noise_refiner: for layer in self.noise_refiner:
# NOTE: mask not used for performance hidden_states = layer(hidden_states, None, noise_rotary_emb, temb)
hidden_states = layer(
hidden_states, hidden_mask if use_mask_in_transformer else None, hidden_rotary_emb, temb # 3. Joint Transformer blocks
) max_seq_len = max(seq_lengths)
use_mask = len(set(seq_lengths)) > 1
attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
attention_mask[i, :seq_len] = True
joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]
hidden_states = joint_hidden_states
# 3. Attention mask preparation
mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
for i in range(batch_size):
cap_len = encoder_hidden_len[i]
img_len = hidden_len[i]
mask[i, : cap_len + img_len] = True
padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len]
padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len]
hidden_states = padded_hidden_states
# 4. Transformer blocks
for layer in self.layers: for layer in self.layers:
# NOTE: mask not used for performance
if torch.is_grad_enabled() and self.gradient_checkpointing: if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func( hidden_states = self._gradient_checkpointing_func(
layer, hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb
) )
else: else:
hidden_states = layer(hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb) hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb)
# 5. Output norm & projection & unpatchify # 4. Output norm & projection
hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.norm_out(hidden_states, temb)
height_tokens = width_tokens = self.config.patch_size # 5. Unpatchify
p = self.config.patch_size
output = [] output = []
for i in range(len(hidden_sizes)): for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
height, width = hidden_sizes[i]
begin = encoder_hidden_len[i]
end = begin + (height // height_tokens) * (width // width_tokens)
output.append( output.append(
hidden_states[i][begin:end] hidden_states[i][encoder_seq_len:seq_len]
.view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels) .view(height // p, width // p, p, p, self.out_channels)
.permute(4, 0, 2, 1, 3) .permute(4, 0, 2, 1, 3)
.flatten(3, 4) .flatten(3, 4)
.flatten(1, 2) .flatten(1, 2)
......
...@@ -24,8 +24,6 @@ from ...models import AutoencoderKL ...@@ -24,8 +24,6 @@ from ...models import AutoencoderKL
from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import ( from ...utils import (
is_bs4_available,
is_ftfy_available,
is_torch_xla_available, is_torch_xla_available,
logging, logging,
replace_example_docstring, replace_example_docstring,
...@@ -44,12 +42,6 @@ else: ...@@ -44,12 +42,6 @@ else:
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_bs4_available():
pass
if is_ftfy_available():
pass
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```py ```py
...@@ -527,7 +519,6 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline): ...@@ -527,7 +519,6 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline):
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
cfg_trunc_ratio: float = 1.0, cfg_trunc_ratio: float = 1.0,
cfg_normalization: bool = True, cfg_normalization: bool = True,
use_mask_in_transformer: bool = True,
max_sequence_length: int = 256, max_sequence_length: int = 256,
) -> Union[ImagePipelineOutput, Tuple]: ) -> Union[ImagePipelineOutput, Tuple]:
""" """
...@@ -599,8 +590,6 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline): ...@@ -599,8 +590,6 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline):
The ratio of the timestep interval to apply normalization-based guidance scale. The ratio of the timestep interval to apply normalization-based guidance scale.
cfg_normalization (`bool`, *optional*, defaults to `True`): cfg_normalization (`bool`, *optional*, defaults to `True`):
Whether to apply normalization-based guidance scale. Whether to apply normalization-based guidance scale.
use_mask_in_transformer (`bool`, *optional*, defaults to `True`):
Whether to use attention mask in `Lumina2Transformer2DModel`. Set `False` for performance gain.
max_sequence_length (`int`, defaults to `256`): max_sequence_length (`int`, defaults to `256`):
Maximum sequence length to use with the `prompt`. Maximum sequence length to use with the `prompt`.
...@@ -706,8 +695,7 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline): ...@@ -706,8 +695,7 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline):
hidden_states=latents, hidden_states=latents,
timestep=current_timestep, timestep=current_timestep,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
attention_mask=prompt_attention_mask, encoder_attention_mask=prompt_attention_mask,
use_mask_in_transformer=use_mask_in_transformer,
return_dict=False, return_dict=False,
)[0] )[0]
...@@ -717,8 +705,7 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline): ...@@ -717,8 +705,7 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline):
hidden_states=latents, hidden_states=latents,
timestep=current_timestep, timestep=current_timestep,
encoder_hidden_states=negative_prompt_embeds, encoder_hidden_states=negative_prompt_embeds,
attention_mask=negative_prompt_attention_mask, encoder_attention_mask=negative_prompt_attention_mask,
use_mask_in_transformer=use_mask_in_transformer,
return_dict=False, return_dict=False,
)[0] )[0]
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
......
...@@ -51,7 +51,7 @@ class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, unittest.TestC ...@@ -51,7 +51,7 @@ class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, unittest.TestC
"hidden_states": hidden_states, "hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states, "encoder_hidden_states": encoder_hidden_states,
"timestep": timestep, "timestep": timestep,
"attention_mask": attention_mask, "encoder_attention_mask": attention_mask,
} }
@property @property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment