"vscode:/vscode.git/clone" did not exist on "ed7c95faf5bf657c41d0838f002100e6a736b4e6"
Commit f6d4fc85 authored by PengGao, committed by GitHub

style: add ruff isort (#183)

parent 878f5a48
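This change applies isort-style import sorting across the repository using Ruff's "I" rule group; the Ruff configuration change itself is not visible in this excerpt. Under Ruff's default isort settings, each file's imports are grouped into standard-library, third-party, first-party, and local (relative) sections, plain import statements come before from-imports within a section, and module and symbol names are alphabetized, which is exactly the reordering shown in the hunks below. A minimal sketch of the resulting layout, using modules that appear in this diff:

# standard library
import math

# third-party
import torch
from loguru import logger

# first-party
from lightx2v.utils.envs import *

# local (relative)
from ..utils import apply_rotary_emb

A change like this is typically regenerated with a command along the lines of "ruff check --select I --fix ."; the exact command and settings used here are not shown.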
 import math
 import torch
 import torch.cuda.amp as amp
 from loguru import logger
 from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
......
-import torch
 import math
-from ..utils import rope_params, sinusoidal_embedding_1d
-from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
+import torch
 from loguru import logger
+from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
+from ..utils import rope_params, sinusoidal_embedding_1d
 class WanAudioPreInfer(WanPreInfer):
     def __init__(self, config):
......
-import torch
 import math
-from ..utils import compute_freqs, compute_freqs_causvid, apply_rotary_emb
+import torch
 from lightx2v.utils.envs import *
 from ..transformer_infer import WanTransformerInfer
+from ..utils import apply_rotary_emb, compute_freqs, compute_freqs_causvid
 class WanTransformerInferCausVid(WanTransformerInfer):
......
 import torch
-from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer
 import torch.distributed as dist
 import torch.nn.functional as F
+from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer
 from lightx2v.models.networks.wan.infer.utils import pad_freqs
......
-from ..transformer_infer import WanTransformerInfer
-from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer
-import torch
-import numpy as np
 import gc
+import numpy as np
+import torch
+from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer
+from ..transformer_infer import WanTransformerInfer
 class WanTransformerInferCaching(WanTransformerInfer):
     def __init__(self, config):
......
 import math
 import torch
 from lightx2v.utils.envs import *
......
 import torch
 from diffusers.models.embeddings import TimestepEmbedding
-from .utils import rope_params, sinusoidal_embedding_1d, guidance_scale_embedding
 from lightx2v.utils.envs import *
+from .utils import guidance_scale_embedding, rope_params, sinusoidal_embedding_1d
 class WanPreInfer:
     def __init__(self, config):
......
+from functools import partial
 import torch
-from .utils import compute_freqs, compute_freqs_audio, apply_rotary_emb, apply_rotary_emb_chunk
 from lightx2v.common.offload.manager import (
-    WeightAsyncStreamManager,
     LazyWeightAsyncStreamManager,
+    WeightAsyncStreamManager,
 )
 from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *
-from functools import partial
+from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio
 class WanTransformerInfer(BaseTransformerInfer):
......
 import torch
 import torch.distributed as dist
 from loguru import logger
 from lightx2v.utils.envs import *
......
+import gc
 import os
 import torch
-from safetensors import safe_open
 from loguru import logger
-import gc
+from safetensors import safe_open
 from lightx2v.utils.envs import *
......
+import glob
+import json
 import os
 import torch
 import torch.distributed as dist
-import glob
-import json
+from loguru import logger
+from safetensors import safe_open
 from lightx2v.common.ops.attn import MaskMap
-from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
-from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
-from lightx2v.models.networks.wan.weights.transformer_weights import (
-    WanTransformerWeights,
-)
-from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
-from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
-from lightx2v.models.networks.wan.infer.transformer_infer import (
-    WanTransformerInfer,
-)
+from lightx2v.models.networks.wan.infer.dist_infer.transformer_infer import WanTransformerDistInfer
 from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
-    WanTransformerInferTeaCaching,
-    WanTransformerInferTaylorCaching,
     WanTransformerInferAdaCaching,
     WanTransformerInferCustomCaching,
-    WanTransformerInferFirstBlock,
     WanTransformerInferDualBlock,
     WanTransformerInferDynamicBlock,
+    WanTransformerInferFirstBlock,
+    WanTransformerInferTaylorCaching,
+    WanTransformerInferTeaCaching,
+)
+from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
+from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
+from lightx2v.models.networks.wan.infer.transformer_infer import (
+    WanTransformerInfer,
+)
+from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
+from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
+from lightx2v.models.networks.wan.weights.transformer_weights import (
+    WanTransformerWeights,
 )
-from lightx2v.models.networks.wan.infer.dist_infer.transformer_infer import WanTransformerDistInfer
-from safetensors import safe_open
 from lightx2v.utils.envs import *
 from lightx2v.utils.utils import *
-from loguru import logger
 class WanModel:
......
+from lightx2v.common.modules.weight_module import WeightModule
 from lightx2v.utils.registry_factory import (
+    LN_WEIGHT_REGISTER,
     MM_WEIGHT_REGISTER,
     TENSOR_REGISTER,
-    LN_WEIGHT_REGISTER,
 )
-from lightx2v.common.modules.weight_module import WeightModule
 class WanPostWeights(WeightModule):
......
+from lightx2v.common.modules.weight_module import WeightModule
 from lightx2v.utils.registry_factory import (
-    MM_WEIGHT_REGISTER,
-    LN_WEIGHT_REGISTER,
     CONV3D_WEIGHT_REGISTER,
+    LN_WEIGHT_REGISTER,
+    MM_WEIGHT_REGISTER,
 )
-from lightx2v.common.modules.weight_module import WeightModule
 class WanPreWeights(WeightModule):
......
-import torch
 import os
+import torch
+from safetensors import safe_open
+from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
 from lightx2v.utils.registry_factory import (
-    MM_WEIGHT_REGISTER,
+    ATTN_WEIGHT_REGISTER,
     LN_WEIGHT_REGISTER,
+    MM_WEIGHT_REGISTER,
     RMS_WEIGHT_REGISTER,
     TENSOR_REGISTER,
-    ATTN_WEIGHT_REGISTER,
 )
-from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
-from safetensors import safe_open
 class WanTransformerWeights(WeightModule):
......
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Tuple, Optional, Union, List, Protocol
+from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
 from lightx2v.utils.utils import save_videos_grid
......
-from diffusers.utils import export_to_video
 import imageio
 import numpy as np
+from diffusers.utils import export_to_video
-from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.models.runners.default_runner import DefaultRunner
-from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.models.input_encoders.hf.t5_v1_1_xxl.model import T5EncoderModel_v1_1_xxl
 from lightx2v.models.networks.cogvideox.model import CogvideoxModel
-from lightx2v.models.video_encoders.hf.cogvideox.model import CogvideoxVAE
+from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.cogvideox.scheduler import CogvideoxXDPMScheduler
+from lightx2v.models.video_encoders.hf.cogvideox.model import CogvideoxVAE
+from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.registry_factory import RUNNER_REGISTER
 @RUNNER_REGISTER("cogvideox")
......
 import gc
-from PIL import Image
-from loguru import logger
 import requests
-from requests.exceptions import RequestException
 import torch
 import torch.distributed as dist
+from PIL import Image
+from loguru import logger
+from requests.exceptions import RequestException
 from lightx2v.utils.envs import *
 from lightx2v.utils.generate_task_id import generate_task_id
 from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
-from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image, cache_video
+from lightx2v.utils.utils import cache_video, save_to_video, vae_to_comfyui_image
 from .base_runner import BaseRunner
......
-from lightx2v.utils.profiler import ProfilingContext4Debug
 from loguru import logger
+from lightx2v.utils.profiler import ProfilingContext4Debug
 class GraphRunner:
     def __init__(self, runner):
......
 import os
 import numpy as np
 import torch
 import torchvision
 from PIL import Image
-from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.models.runners.default_runner import DefaultRunner
-from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
-from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching, HunyuanSchedulerAdaCaching, HunyuanSchedulerCustomCaching
-from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
 from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
+from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
 from lightx2v.models.input_encoders.hf.llava.model import TextEncoderHFLlavaModel
 from lightx2v.models.networks.hunyuan.model import HunyuanModel
+from lightx2v.models.runners.default_runner import DefaultRunner
+from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerAdaCaching, HunyuanSchedulerCustomCaching, HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
+from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
 from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
-from lightx2v.utils.utils import save_videos_grid
 from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.registry_factory import RUNNER_REGISTER
+from lightx2v.utils.utils import save_videos_grid
 @RUNNER_REGISTER("hunyuan")
......
-import os
 import gc
+import os
+import subprocess
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple
 import numpy as np
 import torch
-import subprocess
 import torchaudio as ta
 from PIL import Image
-from contextlib import contextmanager
-from typing import Optional, Tuple, List, Dict, Any
-from dataclasses import dataclass
-from loguru import logger
 from einops import rearrange
-from transformers import AutoFeatureExtractor
+from loguru import logger
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms.functional import resize
+from transformers import AutoFeatureExtractor
-from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.models.runners.wan.wan_runner import WanRunner, MultiModelStruct
-from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
-from lightx2v.models.networks.wan.audio_model import WanAudioModel, Wan22MoeAudioModel
-from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
-from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
+from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudioModel
+from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
+from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
 from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
+from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
+from lightx2v.utils.registry_factory import RUNNER_REGISTER
+from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
 @contextmanager
......