Commit f6d4fc85 authored by PengGao's avatar PengGao Committed by GitHub
Browse files

style: add ruff isort (#183)

parent 878f5a48
import os
import gc import gc
import os
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
from PIL import Image from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER from loguru import logger
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.causvid_model import WanCausVidModel from lightx2v.models.networks.wan.causvid_model import WanCausVidModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from loguru import logger from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
import torch.distributed as dist from lightx2v.utils.registry_factory import RUNNER_REGISTER
@RUNNER_REGISTER("wan2.1_causvid") @RUNNER_REGISTER("wan2.1_causvid")
......
import os import os
import numpy as np import numpy as np
import torch import torch
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
from PIL import Image from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER from loguru import logger
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.distill_model import WanDistillModel from lightx2v.models.networks.wan.distill_model import WanDistillModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import cache_video from lightx2v.utils.utils import cache_video
from loguru import logger
@RUNNER_REGISTER("wan2.1_distill") @RUNNER_REGISTER("wan2.1_distill")
......
import os
import gc import gc
import os
import numpy as np import numpy as np
import torch import torch
import torchvision.transforms.functional as TF
import torch.distributed as dist import torch.distributed as dist
from loguru import logger import torchvision.transforms.functional as TF
from PIL import Image from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER from loguru import logger
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import Wan22MoeModel, WanModel
from lightx2v.models.runners.default_runner import DefaultRunner from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.changing_resolution.scheduler import ( from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
WanScheduler4ChangingResolutionInterface, WanScheduler4ChangingResolutionInterface,
) )
...@@ -16,16 +20,14 @@ from lightx2v.models.schedulers.wan.feature_caching.scheduler import ( ...@@ -16,16 +20,14 @@ from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
WanSchedulerCaching, WanSchedulerCaching,
WanSchedulerTaylorCaching, WanSchedulerTaylorCaching,
) )
from lightx2v.utils.utils import * from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.model import WanModel, Wan22MoeModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
from lightx2v.utils.utils import cache_video, best_output_size
from lightx2v.utils.profiler import ProfilingContext from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
from lightx2v.utils.utils import best_output_size, cache_video
@RUNNER_REGISTER("wan2.1") @RUNNER_REGISTER("wan2.1")
......
import os
import gc import gc
import os
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
from PIL import Image from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER from loguru import logger
from lightx2v.models.runners.wan.wan_runner import WanRunner from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.df.skyreels_v2_df_scheduler import WanSkyreelsV2DFScheduler from lightx2v.models.schedulers.wan.df.skyreels_v2_df_scheduler import WanSkyreelsV2DFScheduler
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.profiler import ProfilingContext from lightx2v.utils.registry_factory import RUNNER_REGISTER
import torch.distributed as dist
from loguru import logger
@RUNNER_REGISTER("wan2.1_skyreels_v2_df") @RUNNER_REGISTER("wan2.1_skyreels_v2_df")
......
import numpy as np
import torch import torch
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.embeddings import get_3d_rotary_pos_embed from diffusers.models.embeddings import get_3d_rotary_pos_embed
import numpy as np from diffusers.utils.torch_utils import randn_tensor
from lightx2v.models.schedulers.scheduler import BaseScheduler from lightx2v.models.schedulers.scheduler import BaseScheduler
......
from ..scheduler import HunyuanScheduler
import torch import torch
from ..scheduler import HunyuanScheduler
class HunyuanSchedulerTeaCaching(HunyuanScheduler): class HunyuanSchedulerTeaCaching(HunyuanScheduler):
def __init__(self, config): def __init__(self, config):
......
import torch from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch
from diffusers.utils.torch_utils import randn_tensor from diffusers.utils.torch_utils import randn_tensor
from typing import Union, Tuple, List
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from lightx2v.models.schedulers.scheduler import BaseScheduler from lightx2v.models.schedulers.scheduler import BaseScheduler
......
import torch import torch
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
......
import os
import gc import gc
import math import math
import numpy as np import os
import torch
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from lightx2v.utils.envs import *
from lightx2v.models.schedulers.scheduler import BaseScheduler
from loguru import logger
from diffusers.configuration_utils import register_to_config import numpy as np
from torch import Tensor import torch
from diffusers import ( from diffusers import (
FlowMatchEulerDiscreteScheduler as FlowMatchEulerDiscreteSchedulerBase, # pyright: ignore FlowMatchEulerDiscreteScheduler as FlowMatchEulerDiscreteSchedulerBase, # pyright: ignore
) )
from diffusers.configuration_utils import register_to_config
from loguru import logger
from torch import Tensor
from lightx2v.models.schedulers.scheduler import BaseScheduler
from lightx2v.utils.envs import *
def unsqueeze_to_ndim(in_tensor: Tensor, tgt_n_dim: int): def unsqueeze_to_ndim(in_tensor: Tensor, tgt_n_dim: int):
......
import torch import torch
from lightx2v.models.schedulers.wan.scheduler import WanScheduler from lightx2v.models.schedulers.wan.scheduler import WanScheduler
......
import os
import math import math
import os
import numpy as np import numpy as np
import torch import torch
......
import gc
import math import math
from typing import List, Optional, Union
import numpy as np import numpy as np
import torch import torch
import gc
from typing import List, Optional, Union
from lightx2v.models.schedulers.scheduler import BaseScheduler from lightx2v.models.schedulers.scheduler import BaseScheduler
from lightx2v.utils.utils import masks_like from lightx2v.utils.utils import masks_like
......
import math import math
from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from typing import List, Optional, Tuple, Union
from lightx2v.models.schedulers.wan.scheduler import WanScheduler from lightx2v.models.schedulers.wan.scheduler import WanScheduler
......
import torch
import numpy as np import numpy as np
import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision.models as models import torchvision.models as models
......
import torch
import torch.nn.functional as F
from math import exp from math import exp
import numpy as np import numpy as np
import torch
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
......
import os import os
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from loguru import logger
import torch import torch
from loguru import logger
from torch.nn import functional as F from torch.nn import functional as F
from lightx2v.utils.profiler import ProfilingContext from lightx2v.utils.profiler import ProfilingContext
......
...@@ -3,6 +3,7 @@ import torch.nn as nn ...@@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ..model.warplayer import warp from ..model.warplayer import warp
# from train_log.refine import * # from train_log.refine import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
......
...@@ -16,12 +16,11 @@ ...@@ -16,12 +16,11 @@
# Modified from diffusers==0.29.2 # Modified from diffusers==0.29.2
# #
# ============================================================================== # ==============================================================================
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.configuration_utils import ConfigMixin, register_to_config
try: try:
...@@ -30,7 +29,6 @@ try: ...@@ -30,7 +29,6 @@ try:
except ImportError: except ImportError:
# Use this to be compatible with the original diffusers. # Use this to be compatible with the original diffusers.
from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS, ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS,
...@@ -41,7 +39,9 @@ from diffusers.models.attention_processor import ( ...@@ -41,7 +39,9 @@ from diffusers.models.attention_processor import (
) )
from diffusers.models.modeling_outputs import AutoencoderKLOutput from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin from diffusers.models.modeling_utils import ModelMixin
from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D from diffusers.utils.accelerate_utils import apply_forward_hook
from .vae import BaseOutput, DecoderCausal3D, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
@dataclass @dataclass
......
import os import os
import torch import torch
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D, DiagonalGaussianDistribution from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D, DiagonalGaussianDistribution
......
...@@ -21,15 +21,12 @@ from typing import Optional, Tuple, Union ...@@ -21,15 +21,12 @@ from typing import Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
from einops import rearrange
from diffusers.utils import logging
from diffusers.models.activations import get_activation from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import SpatialNorm from diffusers.models.attention_processor import Attention, SpatialNorm
from diffusers.models.attention_processor import Attention from diffusers.models.normalization import AdaGroupNorm, RMSNorm
from diffusers.models.normalization import AdaGroupNorm from diffusers.utils import logging
from diffusers.models.normalization import RMSNorm from einops import rearrange
from torch import nn
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment