"tests/git@developer.sourcefind.cn:norm/vllm.git" did not exist on "c07ece5ca490a90b2b19c33ab7da2d21e015d7bd"
Commit c413353e authored by William Berman, committed by Will Berman

add `encoder_hid_dim` to unet

`encoder_hid_dim` provides an additional projection for the input `encoder_hidden_states` from `encoder_hid_dim` to `cross_attention_dim`.
parent 8db5e5b3
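
The change in a nutshell: when `encoder_hid_dim` is set, the UNet builds a single `nn.Linear(encoder_hid_dim, cross_attention_dim)` and applies it to `encoder_hidden_states` at the start of the forward pass, before the cross-attention blocks see them. A minimal standalone sketch of that projection (the sizes below are illustrative, not taken from the commit):

```python
import torch
import torch.nn as nn

# Illustrative sizes: a 1024-wide text encoder feeding cross-attention
# layers that expect 1280-dimensional context.
encoder_hid_dim, cross_attention_dim = 1024, 1280
encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)

# (batch, tokens, encoder_hid_dim) -> (batch, tokens, cross_attention_dim)
encoder_hidden_states = torch.randn(1, 77, encoder_hid_dim)
projected = encoder_hid_proj(encoder_hidden_states)
print(projected.shape)  # torch.Size([1, 77, 1280])
```
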
@@ -88,6 +88,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
encoder_hid_dim (`int`, *optional*, defaults to None):
If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
@@ -139,6 +141,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
encoder_hid_dim: Optional[int] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
@@ -224,6 +227,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
cond_proj_dim=time_cond_proj_dim,
)
if encoder_hid_dim is not None:
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
else:
self.encoder_hid_proj = None
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
@@ -626,6 +634,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
else:
emb = emb + class_emb
if self.encoder_hid_proj is not None:
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
# 2. pre-process
sample = self.conv_in(sample)
...
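
End-to-end, the option is consumed through the model config. A hedged usage sketch (the configuration values below are hypothetical and down-scaled; only `encoder_hid_dim` and the projection behavior come from this commit):

```python
import torch
from diffusers import UNet2DConditionModel

# Hypothetical small config; `encoder_hid_dim` is the new argument.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    block_out_channels=(64, 128),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=128,
    encoder_hid_dim=1024,  # text-encoder width differs from cross_attention_dim
)

sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 77, 1024)  # last dim matches encoder_hid_dim
out = unet(sample, timestep=1, encoder_hidden_states=encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])
```
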
@@ -169,6 +169,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
encoder_hid_dim (`int`, *optional*, defaults to None):
If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
@@ -225,6 +227,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
encoder_hid_dim: Optional[int] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
@@ -316,6 +319,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
cond_proj_dim=time_cond_proj_dim,
)
if encoder_hid_dim is not None:
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
else:
self.encoder_hid_proj = None
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
@@ -718,6 +726,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
else:
emb = emb + class_emb
if self.encoder_hid_proj is not None:
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
# 2. pre-process
sample = self.conv_in(sample)
...