Unverified Commit 0c0c2224 authored by Mishig Davaadorj's avatar Mishig Davaadorj Committed by GitHub
Browse files

FlaxUNet2DConditionOutput @flax.struct.dataclass (#550)

parent d09bbae5
from dataclasses import dataclass
from typing import Tuple, Union from typing import Tuple, Union
import flax
import flax.linen as nn import flax.linen as nn
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -19,7 +19,7 @@ from .unet_blocks_flax import ( ...@@ -19,7 +19,7 @@ from .unet_blocks_flax import (
) )
@dataclass @flax.struct.dataclass
class FlaxUNet2DConditionOutput(BaseOutput): class FlaxUNet2DConditionOutput(BaseOutput):
""" """
Args: Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment