"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "be6f6c2927dc03b6103af8d48a961562dd5d68d5"
Unverified Commit abd922bd authored by Aryan's avatar Aryan Committed by GitHub
Browse files

[docs] unet type hints (#7134)

update
parent fa633ed6
...@@ -204,7 +204,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, ...@@ -204,7 +204,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
upcast_attention: bool = False, upcast_attention: bool = False,
resnet_time_scale_shift: str = "default", resnet_time_scale_shift: str = "default",
resnet_skip_time_act: bool = False, resnet_skip_time_act: bool = False,
resnet_out_scale_factor: int = 1.0, resnet_out_scale_factor: float = 1.0,
time_embedding_type: str = "positional", time_embedding_type: str = "positional",
time_embedding_dim: Optional[int] = None, time_embedding_dim: Optional[int] = None,
time_embedding_act_fn: Optional[str] = None, time_embedding_act_fn: Optional[str] = None,
...@@ -217,7 +217,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, ...@@ -217,7 +217,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
class_embeddings_concat: bool = False, class_embeddings_concat: bool = False,
mid_block_only_cross_attention: Optional[bool] = None, mid_block_only_cross_attention: Optional[bool] = None,
cross_attention_norm: Optional[str] = None, cross_attention_norm: Optional[str] = None,
addition_embed_type_num_heads=64, addition_embed_type_num_heads: int = 64,
): ):
super().__init__() super().__init__()
...@@ -485,9 +485,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, ...@@ -485,9 +485,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
up_block_types: Tuple[str], up_block_types: Tuple[str],
only_cross_attention: Union[bool, Tuple[bool]], only_cross_attention: Union[bool, Tuple[bool]],
block_out_channels: Tuple[int], block_out_channels: Tuple[int],
layers_per_block: [int, Tuple[int]], layers_per_block: Union[int, Tuple[int]],
cross_attention_dim: Union[int, Tuple[int]], cross_attention_dim: Union[int, Tuple[int]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]], transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
reverse_transformer_layers_per_block: bool, reverse_transformer_layers_per_block: bool,
attention_head_dim: int, attention_head_dim: int,
num_attention_heads: Optional[Union[int, Tuple[int]]], num_attention_heads: Optional[Union[int, Tuple[int]]],
...@@ -762,7 +762,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, ...@@ -762,7 +762,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
self.set_attn_processor(processor) self.set_attn_processor(processor)
def set_attention_slice(self, slice_size): def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
r""" r"""
Enable sliced attention computation. Enable sliced attention computation.
...@@ -831,7 +831,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, ...@@ -831,7 +831,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
if hasattr(module, "gradient_checkpointing"): if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value module.gradient_checkpointing = value
def enable_freeu(self, s1, s2, b1, b2): def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied. The suffixes after the scaling factors represent the stage blocks where they are being applied.
...@@ -953,7 +953,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, ...@@ -953,7 +953,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
return class_emb return class_emb
def get_aug_embed( def get_aug_embed(
self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
aug_emb = None aug_emb = None
if self.config.addition_embed_type == "text": if self.config.addition_embed_type == "text":
...@@ -1004,7 +1004,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, ...@@ -1004,7 +1004,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
aug_emb = self.add_embedding(image_embs, hint) aug_emb = self.add_embedding(image_embs, hint)
return aug_emb return aug_emb
def process_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor, added_cond_kwargs) -> torch.Tensor: def process_encoder_hidden_states(
self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
) -> torch.Tensor:
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment