    Allow `UNet2DModel` to use arbitrary class embeddings (#2080) · 915a5636
    Pedro Cuenca authored
    * Allow `UNet2DModel` to use arbitrary class embeddings.
    
    We can currently use class conditioning in `UNet2DConditionModel`, but
    not in `UNet2DModel`. However, `UNet2DConditionModel` requires text
    conditioning too, which is unrelated to other types of conditioning.
    This commit makes it possible for `UNet2DModel` to be conditioned on
    entities other than timesteps. This is useful for training /
    research purposes. We can currently train models to perform
    unconditional image generation or text-to-image generation, but it's not
    straightforward to train a model to perform class-conditioned image
    generation when text conditioning is not required.
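    The kind of conditioning described above can be sketched in plain Python.
    This assumes the common approach of summing a learned per-class vector
    with the timestep embedding. All names here (`timestep_embedding`,
    `ClassConditionedEmbedding`) are hypothetical and are not the diffusers
    API. Only `num_class_embeds` comes from this commit. The sinusoidal
    embedding is simplified, and a real UNet would add the class vector to
    the projected (MLP) time embedding using trainable weights.

```python
import math
import random
from typing import List, Optional


def timestep_embedding(t: int, dim: int) -> List[float]:
    """Simplified sinusoidal timestep embedding (hypothetical, not diffusers' exact variant)."""
    half = dim // 2
    freqs = [math.exp(-math.log(10000.0) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]


class ClassConditionedEmbedding:
    """Hypothetical sketch: one vector per class, added to the timestep embedding."""

    def __init__(self, num_class_embeds: int, dim: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.dim = dim
        # Lookup table with one row per class; a real model would learn these weights.
        self.table = [
            [rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_class_embeds)
        ]

    def __call__(self, t: int, class_label: Optional[int] = None) -> List[float]:
        # With class conditioning enabled, a label is mandatory.
        if class_label is None:
            raise ValueError("class_label must be provided for class conditioning")
        if not 0 <= class_label < len(self.table):
            raise ValueError(f"class_label {class_label} is out of range")
        emb = timestep_embedding(t, self.dim)
        class_emb = self.table[class_label]
        # Conditioning signal = timestep embedding + class embedding (element-wise).
        return [a + b for a, b in zip(emb, class_emb)]


if __name__ == "__main__":
    embed = ClassConditionedEmbedding(num_class_embeds=10, dim=8)
    print(embed(t=50, class_label=3))
```

    No text encoder or `encoder_hidden_states` is involved. The class label
    enters through the same additive path as the timestep, which is why
    conditioning on classes does not require cross-attention.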
    
    We could potentially use `UNet2DConditionModel` for class conditioning
    without text embeddings by using down/up blocks without
    cross-attention. However:
    - The mid block currently requires cross attention.
    - We are required to provide `encoder_hidden_states` to `forward`.
    
    * Style
    
    * Align class conditioning, add docstring for `num_class_embeds`.
    
    * Copy docstring to versatile_diffusion UNetFlatConditionModel