Unverified commit a9a5b14f, authored by Sayak Paul and committed by GitHub

[Core] refactor transformers 2d into multiple init variants. (#7491)

* refactor transformers 2d into multiple legacy variants.

* fix: init.

* fix recursive init.

* add inits.

* make transformer block creation more modular.

* complete refactor.

* remove forward

* debug

* remove legacy blocks and refactor within the module itself.

* remove print

* guard caption projection

* remove fetcher.

* reduce the number of args.

* fix: norm_type

* group variables that are shared.

* remove _get_transformer_blocks

* harmonize the init function signatures.

* transformer_blocks to common

* repeat .
parent aa190259
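
The commit is a restructuring of `Transformer2DModel.__init__`: the constructor still infers the input mode from which arguments are provided (`in_channels` for continuous latents, `num_vector_embeds` for discrete tokens, `patch_size` for patched DiT/PixArt-style inputs) and now dispatches to one `_init_*` helper per mode, as the diff below shows. A minimal usage sketch of the three modes; the argument values are illustrative and not taken from the commit:

```python
from diffusers.models import Transformer2DModel

# Continuous inputs (UNet-style attention block): `in_channels` set, no `patch_size`.
continuous = Transformer2DModel(num_attention_heads=2, attention_head_dim=8, in_channels=32, num_layers=1)

# Vectorized (discrete) inputs (VQ-Diffusion-style): `num_vector_embeds` and `sample_size` set.
vectorized = Transformer2DModel(
    num_attention_heads=2, attention_head_dim=8, sample_size=8, num_vector_embeds=10, num_layers=1
)

# Patched inputs (DiT/PixArt-style): `patch_size` set and an "ada_norm*" norm type.
patched = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    sample_size=8,
    patch_size=2,
    norm_type="ada_norm_single",
    num_layers=1,
)

# Each model goes through exactly one of the new helpers:
# _init_continuous_input, _init_vectorized_inputs, _init_patched_inputs.
print(continuous.is_input_continuous, vectorized.is_input_vectorized, patched.is_input_patches)  # True True True
```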
@@ -102,6 +102,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         interpolation_scale: float = None,
     ):
         super().__init__()
+
+        # Validate inputs.
         if patch_size is not None:
             if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
                 raise NotImplementedError(
@@ -112,10 +114,16 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                     f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
                 )
 
+        # Set some common variables used across the board.
         self.use_linear_projection = use_linear_projection
+        self.interpolation_scale = interpolation_scale
+        self.caption_channels = caption_channels
         self.num_attention_heads = num_attention_heads
         self.attention_head_dim = attention_head_dim
-        inner_dim = num_attention_heads * attention_head_dim
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.gradient_checkpointing = False
 
         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
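
The hunk above can already read `self.config.num_attention_heads` inside `__init__` because `Transformer2DModel.__init__` is wrapped with diffusers' `register_to_config`, which records every constructor argument on `self.config` before the body runs. That is also what lets the `_init_*` helpers in the next hunk take only `norm_type` instead of the long argument list. A small, self-contained sketch of the pattern; `TinyModel` is a hypothetical class, not part of the commit:

```python
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class TinyModel(ModelMixin, ConfigMixin):
    config_name = "config.json"  # ConfigMixin subclasses need a config file name

    @register_to_config
    def __init__(self, num_attention_heads: int = 16, attention_head_dim: int = 88):
        super().__init__()
        # The decorator has already populated `self.config`, so derived values can be
        # computed from it here, mirroring `self.inner_dim = ...` in the diff above.
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim


model = TinyModel(num_attention_heads=2, attention_head_dim=8)
print(model.config.num_attention_heads, model.inner_dim)  # 2 16
```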
@@ -150,104 +158,167 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                 f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
             )
 
-        # 2. Define input layers
-        if self.is_input_continuous:
-            self.in_channels = in_channels
-
-            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
-            if use_linear_projection:
-                self.proj_in = nn.Linear(in_channels, inner_dim)
-            else:
-                self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
-        elif self.is_input_vectorized:
-            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
-            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
-
-            self.height = sample_size
-            self.width = sample_size
-            self.num_vector_embeds = num_vector_embeds
-            self.num_latent_pixels = self.height * self.width
-
-            self.latent_image_embedding = ImagePositionalEmbeddings(
-                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
-            )
-        elif self.is_input_patches:
-            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
-
-            self.height = sample_size
-            self.width = sample_size
-
-            self.patch_size = patch_size
-            interpolation_scale = (
-                interpolation_scale if interpolation_scale is not None else max(self.config.sample_size // 64, 1)
-            )
-            self.pos_embed = PatchEmbed(
-                height=sample_size,
-                width=sample_size,
-                patch_size=patch_size,
-                in_channels=in_channels,
-                embed_dim=inner_dim,
-                interpolation_scale=interpolation_scale,
-            )
-
-        # 3. Define transformers blocks
-        self.transformer_blocks = nn.ModuleList(
-            [
-                BasicTransformerBlock(
-                    inner_dim,
-                    num_attention_heads,
-                    attention_head_dim,
-                    dropout=dropout,
-                    cross_attention_dim=cross_attention_dim,
-                    activation_fn=activation_fn,
-                    num_embeds_ada_norm=num_embeds_ada_norm,
-                    attention_bias=attention_bias,
-                    only_cross_attention=only_cross_attention,
-                    double_self_attention=double_self_attention,
-                    upcast_attention=upcast_attention,
-                    norm_type=norm_type,
-                    norm_elementwise_affine=norm_elementwise_affine,
-                    norm_eps=norm_eps,
-                    attention_type=attention_type,
-                )
-                for d in range(num_layers)
-            ]
-        )
-
-        # 4. Define output layers
-        self.out_channels = in_channels if out_channels is None else out_channels
-        if self.is_input_continuous:
-            # TODO: should use out_channels for continuous projections
-            if use_linear_projection:
-                self.proj_out = nn.Linear(inner_dim, in_channels)
-            else:
-                self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
-        elif self.is_input_vectorized:
-            self.norm_out = nn.LayerNorm(inner_dim)
-            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
-        elif self.is_input_patches and norm_type != "ada_norm_single":
-            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
-            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
-            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-        elif self.is_input_patches and norm_type == "ada_norm_single":
-            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
-            self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
-            self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-
-        # 5. PixArt-Alpha blocks.
-        self.adaln_single = None
-        self.use_additional_conditions = False
-        if norm_type == "ada_norm_single":
-            self.use_additional_conditions = self.config.sample_size == 128
-            # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
-            # additional conditions until we find better name
-            self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
-
-        self.caption_projection = None
-        if caption_channels is not None:
-            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
-
-        self.gradient_checkpointing = False
+        # 2. Initialize the right blocks.
+        # These functions follow a common structure:
+        # a. Initialize the input blocks. b. Initialize the transformer blocks.
+        # c. Initialize the output blocks and other projection blocks when necessary.
+        if self.is_input_continuous:
+            self._init_continuous_input(norm_type=norm_type)
+        elif self.is_input_vectorized:
+            self._init_vectorized_inputs(norm_type=norm_type)
+        elif self.is_input_patches:
+            self._init_patched_inputs(norm_type=norm_type)
+
+    def _init_continuous_input(self, norm_type):
+        self.norm = torch.nn.GroupNorm(
+            num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True
+        )
+        if self.use_linear_projection:
+            self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim)
+        else:
+            self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
+
+        if self.use_linear_projection:
+            self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels)
+        else:
+            self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0)
+
+    def _init_vectorized_inputs(self, norm_type):
+        assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+        assert (
+            self.config.num_vector_embeds is not None
+        ), "Transformer2DModel over discrete input must provide num_embed"
+
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+        self.num_latent_pixels = self.height * self.width
+
+        self.latent_image_embedding = ImagePositionalEmbeddings(
+            num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
+
+        self.norm_out = nn.LayerNorm(self.inner_dim)
+        self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1)
+
+    def _init_patched_inputs(self, norm_type):
+        assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+
+        self.patch_size = self.config.patch_size
+        interpolation_scale = (
+            self.config.interpolation_scale
+            if self.config.interpolation_scale is not None
+            else max(self.config.sample_size // 64, 1)
+        )
+        self.pos_embed = PatchEmbed(
+            height=self.config.sample_size,
+            width=self.config.sample_size,
+            patch_size=self.config.patch_size,
+            in_channels=self.in_channels,
+            embed_dim=self.inner_dim,
+            interpolation_scale=interpolation_scale,
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
+
+        if self.config.norm_type != "ada_norm_single":
+            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+            self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
+            self.proj_out_2 = nn.Linear(
+                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+            )
+        elif self.config.norm_type == "ada_norm_single":
+            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+            self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+            self.proj_out = nn.Linear(
+                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+            )
+
+        # PixArt-Alpha blocks.
+        self.adaln_single = None
+        self.use_additional_conditions = False
+        if self.config.norm_type == "ada_norm_single":
+            self.use_additional_conditions = self.config.sample_size == 128
+            # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
+            # additional conditions until we find better name
+            self.adaln_single = AdaLayerNormSingle(
+                self.inner_dim, use_additional_conditions=self.use_additional_conditions
+            )
+
+        self.caption_projection = None
+        if self.caption_channels is not None:
+            self.caption_projection = PixArtAlphaTextProjection(
+                in_features=self.caption_channels, hidden_size=self.inner_dim
+            )
 
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
@@ -361,7 +432,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                 )
 
         # 2. Blocks
-        if self.caption_projection is not None:
+        if self.is_input_patches and self.caption_projection is not None:
            batch_size = hidden_states.shape[0]
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
......
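
The only change to `forward` is the stricter guard shown in the last hunk: after the refactor, `caption_projection` is created in `_init_patched_inputs` alone, so continuous and vectorized models never define the attribute and the added `self.is_input_patches` check keeps the projection path from being evaluated for them. A short, self-contained sketch with illustrative argument values:

```python
from diffusers.models import Transformer2DModel

# PixArt-style patched model: `caption_channels` triggers PixArtAlphaTextProjection.
pixart_like = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    sample_size=8,
    patch_size=2,
    norm_type="ada_norm_single",
    cross_attention_dim=16,
    caption_channels=24,
    num_layers=1,
)

# UNet-style continuous model: `_init_continuous_input` never defines `caption_projection`.
continuous = Transformer2DModel(num_attention_heads=2, attention_head_dim=8, in_channels=32, num_layers=1)

print(pixart_like.caption_projection is not None)  # True
print(hasattr(continuous, "caption_projection"))   # False, hence the extra `is_input_patches` check in forward()
```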