Unverified Commit 4a343077 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

add: utility to format our docs too 📜 (#7314)

* add: utility to format our docs too 📜

* debugging saga

* fix: message

* checking

* should be fixed.

* revert pipeline_fixture

* remove empty line

* make style

* fix: setup.py

* style.
parent 8e963d1c
......@@ -32,9 +32,7 @@ jobs:
python -m pip install --upgrade pip
pip install .[quality]
- name: Check quality
run: |
ruff check examples tests src utils scripts
ruff format examples tests src utils scripts --check
run: make quality
- name: Check if failure
if: ${{ failure() }}
run: |
......@@ -53,7 +51,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check quality
- name: Check repo consistency
run: |
python utils/check_copies.py
python utils/check_dummies.py
......
......@@ -40,9 +40,7 @@ jobs:
python -m pip install --upgrade pip
pip install .[quality]
- name: Check quality
run: |
ruff check examples tests src utils scripts
ruff format examples tests src utils scripts --check
run: make quality
- name: Check if failure
if: ${{ failure() }}
run: |
......@@ -61,7 +59,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check quality
- name: Check repo consistency
run: |
python utils/check_copies.py
python utils/check_dummies.py
......
......@@ -42,6 +42,7 @@ repo-consistency:
quality:
ruff check $(check_dirs) setup.py
ruff format --check $(check_dirs) setup.py
doc-builder style src/diffusers docs/source --max_len 119 --check_only
python utils/check_doc_toc.py
# Format source code automatically and check is there are any problems left that need manual fixing
......@@ -55,6 +56,7 @@ extra_style_checks:
style:
ruff check $(check_dirs) setup.py --fix
ruff format $(check_dirs) setup.py
doc-builder style src/diffusers docs/source --max_len 119
${MAKE} autogenerate_code
${MAKE} extra_style_checks
......
......@@ -134,6 +134,7 @@ _deps = [
"torchvision",
"transformers>=4.25.1",
"urllib3<=2.0.0",
"black",
]
# this is a lookup table with items like:
......
......@@ -42,4 +42,5 @@ deps = {
"torchvision": "torchvision",
"transformers": "transformers>=4.25.1",
"urllib3": "urllib3<=2.0.0",
"black": "black",
}
......@@ -173,8 +173,9 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
"""
Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image;
for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.
Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect
ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for
processing are 512x512, the region will be expanded to 128x128.
Args:
mask_image (PIL.Image.Image): Mask image.
......@@ -183,7 +184,8 @@ class VaeImageProcessor(ConfigMixin):
pad (int, optional): Padding to be added to the crop region. Defaults to 0.
Returns:
tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and matches the original aspect ratio.
tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and
matches the original aspect ratio.
"""
mask_image = mask_image.convert("L")
......@@ -265,7 +267,8 @@ class VaeImageProcessor(ConfigMixin):
height: int,
) -> PIL.Image.Image:
"""
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
the image within the dimensions, filling empty with data from image.
Args:
image: The image to resize.
......@@ -309,7 +312,8 @@ class VaeImageProcessor(ConfigMixin):
height: int,
) -> PIL.Image.Image:
"""
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
the image within the dimensions, cropping the excess.
Args:
image: The image to resize.
......@@ -346,12 +350,12 @@ class VaeImageProcessor(ConfigMixin):
The width to resize to.
resize_mode (`str`, *optional*, defaults to `default`):
The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
within the specified width and height, and it may not maintaining the original aspect ratio.
If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
within the dimensions, filling empty with data from image.
If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
within the dimensions, cropping the excess.
Note that resize_mode `fill` and `crop` are only supported for PIL image input.
within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`,
will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
then center the image within the dimensions, filling empty with data from image. If `crop`, will resize
the image to fit within the specified width and height, maintaining the aspect ratio, and then center
the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
supported for PIL image input.
Returns:
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
......@@ -456,19 +460,21 @@ class VaeImageProcessor(ConfigMixin):
Args:
image (`pipeline_image_input`):
The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats.
The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
supported formats.
height (`int`, *optional*, defaults to `None`):
The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height.
The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
height.
width (`int`, *optional*`, defaults to `None`):
The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
resize_mode (`str`, *optional*, defaults to `default`):
The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit
within the specified width and height, and it may not maintaining the original aspect ratio.
If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
within the dimensions, filling empty with data from image.
If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
within the dimensions, cropping the excess.
Note that resize_mode `fill` and `crop` are only supported for PIL image input.
The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
image to fit within the specified width and height, maintaining the aspect ratio, and then center the
image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
supported for PIL image input.
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
The crop coordinates for each image in the batch. If `None`, will not crop the image.
"""
......@@ -930,8 +936,8 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
@staticmethod
def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
"""
Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention.
If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
Args:
mask (`torch.FloatTensor`):
......
......@@ -67,17 +67,18 @@ class IPAdapterMixin:
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
If a list is passed, it should have the same length as `weight_name`.
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`weight_name`.
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside `subfolder`,
you only need to pass the name of the folder that contains image encoder weights, e.g. `image_encoder_folder="image_encoder"`.
If the image encoder is located in a folder other than `subfolder`, you should pass the path to the folder that contains image encoder weights,
for example, `image_encoder_folder="different_subfolder/image_encoder"`.
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
`image_encoder_folder="different_subfolder/image_encoder"`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
......
......@@ -20,7 +20,8 @@ from ..utils import MIN_PEFT_VERSION, check_peft_version, is_peft_available
class PeftAdapterMixin:
"""
A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
more details about adapters and injecting them in a transformer-based model, check out the PEFT [documentation](https://huggingface.co/docs/peft/index).
more details about adapters and injecting them in a transformer-based model, check out the PEFT
[documentation](https://huggingface.co/docs/peft/index).
Install the latest version of PEFT, and use this mixin to:
......@@ -143,8 +144,8 @@ class PeftAdapterMixin:
def enable_adapters(self) -> None:
"""
Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the
list of adapters to enable.
Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the list of
adapters to enable.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
[documentation](https://huggingface.co/docs/peft).
......
......@@ -198,19 +198,24 @@ class FromSingleFileMixin:
model_type (`str`, *optional*):
The type of model to load. If not provided, the model type will be inferred from the checkpoint file.
image_size (`int`, *optional*):
The size of the image output. It's used to configure the `sample_size` parameter of the UNet and VAE model.
The size of the image output. It's used to configure the `sample_size` parameter of the UNet and VAE
model.
load_safety_checker (`bool`, *optional*, defaults to `False`):
Whether to load the safety checker model or not. By default, the safety checker is not loaded unless a `safety_checker` component is passed to the `kwargs`.
Whether to load the safety checker model or not. By default, the safety checker is not loaded unless a
`safety_checker` component is passed to the `kwargs`.
num_in_channels (`int`, *optional*):
Specify the number of input channels for the UNet model. Read more about how to configure UNet model with this parameter
Specify the number of input channels for the UNet model. Read more about how to configure UNet model
with this parameter
[here](https://huggingface.co/docs/diffusers/training/adapt_a_model#configure-unet2dconditionmodel-parameters).
scaling_factor (`float`, *optional*):
The scaling factor to use for the VAE model. If not provided, it is inferred from the config file first.
If the scaling factor is not found in the config file, the default value 0.18215 is used.
The scaling factor to use for the VAE model. If not provided, it is inferred from the config file
first. If the scaling factor is not found in the config file, the default value 0.18215 is used.
scheduler_type (`str`, *optional*):
The type of scheduler to load. If not provided, the scheduler type will be inferred from the checkpoint file.
The type of scheduler to load. If not provided, the scheduler type will be inferred from the checkpoint
file.
prediction_type (`str`, *optional*):
The type of prediction to load. If not provided, the prediction type will be inferred from the checkpoint file.
The type of prediction to load. If not provided, the prediction type will be inferred from the
checkpoint file.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
......
......@@ -487,20 +487,35 @@ class TextualInversionLoaderMixin:
# Example 3: unload from SDXL
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
embedding_path = hf_hub_download(repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model")
embedding_path = hf_hub_download(
repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model"
)
# load embeddings to the text encoders
state_dict = load_file(embedding_path)
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
pipeline.load_textual_inversion(
state_dict["clip_l"],
token=["<s0>", "<s1>"],
text_encoder=pipeline.text_encoder,
tokenizer=pipeline.tokenizer,
)
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
pipeline.load_textual_inversion(
state_dict["clip_g"],
token=["<s0>", "<s1>"],
text_encoder=pipeline.text_encoder_2,
tokenizer=pipeline.tokenizer_2,
)
# Unload explicitly from both text encoders abd tokenizers
pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
pipeline.unload_textual_inversion(
tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
)
pipeline.unload_textual_inversion(
tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2
)
```
"""
......
......@@ -74,37 +74,24 @@ def _maybe_expand_lora_scales_for_one_adapter(
E.g. turns
```python
scales = {
'down': 2,
'mid': 3,
'up': {
'block_0': 4,
'block_1': [5, 6, 7]
}
}
blocks_with_transformer = {
'down': [1,2],
'up': [0,1]
}
transformer_per_block = {
'down': 2,
'up': 3
}
scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
transformer_per_block = {"down": 2, "up": 3}
```
into
```python
{
'down.block_1.0': 2,
'down.block_1.1': 2,
'down.block_2.0': 2,
'down.block_2.1': 2,
'mid': 3,
'up.block_0.0': 4,
'up.block_0.1': 4,
'up.block_0.2': 4,
'up.block_1.0': 5,
'up.block_1.1': 6,
'up.block_1.2': 7,
"down.block_1.0": 2,
"down.block_1.1": 2,
"down.block_2.0": 2,
"down.block_2.1": 2,
"mid": 3,
"up.block_0.0": 4,
"up.block_0.1": 4,
"up.block_0.2": 4,
"up.block_1.0": 5,
"up.block_1.1": 6,
"up.block_1.2": 7,
}
```
"""
......
......@@ -1298,9 +1298,9 @@ class AttnProcessor2_0:
class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
......
......@@ -453,8 +453,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
......
......@@ -329,15 +329,15 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
plain tuple.
Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of
a plain tuple.
train (`bool`, *optional*, defaults to `False`):
Use deterministic functions and disable dropout when not training.
Returns:
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is the sample tensor.
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise
a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
channel_order = self.controlnet_conditioning_channel_order
if channel_order == "bgr":
......
......@@ -795,16 +795,13 @@ class IPAdapterPlusImageProjection(nn.Module):
Args:
----
embed_dims (int): The feature dimension. Defaults to 768.
output_dims (int): The number of output channels, that is the same
number of the channels in the
`unet.config.cross_attention_dim`. Defaults to 1024.
hidden_dims (int): The number of hidden channels. Defaults to 1280.
depth (int): The number of blocks. Defaults to 8.
dim_head (int): The number of head channels. Defaults to 64.
heads (int): Parallel attention heads. Defaults to 16.
num_queries (int): The number of queries. Defaults to 8.
ffn_ratio (float): The expansion ratio of feedforward network hidden
embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
that is the same
number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
hidden_dims (int): The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
Defaults to 16. num_queries (int): The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
of feedforward network hidden
layer channels. Defaults to 4.
"""
......
......@@ -202,8 +202,8 @@ class ResnetBlock2D(nn.Module):
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift"
for a stronger conditioning with scale and shift.
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a
stronger conditioning with scale and shift.
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
......
......@@ -120,7 +120,8 @@ class DualTransformer2DModel(nn.Module):
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
......
......@@ -294,8 +294,8 @@ class TransformerSpatioTemporalModel(nn.Module):
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
images, 0 indicates that the input contains video frames.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
tuple.
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a
plain tuple.
Returns:
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
......
......@@ -865,8 +865,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
......@@ -1093,8 +1093,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
Returns:
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor.
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
otherwise a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
......
......@@ -76,7 +76,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer is skipped.
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer
is skipped.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
......@@ -350,15 +351,15 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
mid_block_additional_residual: (`torch.Tensor`, *optional*):
A tensor that if specified is added to the residual of the middle unet block.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
plain tuple.
Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of
a plain tuple.
train (`bool`, *optional*, defaults to `False`):
Use deterministic functions and disable dropout when not training.
Returns:
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
# 1. time
if not isinstance(timesteps, jnp.ndarray):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment