Commit f6d4fc85 authored by PengGao's avatar PengGao Committed by GitHub
Browse files

style: add ruff isort (#183)

parent 878f5a48
......@@ -7,11 +7,11 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from loguru import logger
# from lightx2v.attentions import attention
from lightx2v.common.ops.attn import TorchSDPAWeight
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8, Q8FQuantLinearInt8, Q8FQuantLinearFp8
from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
__all__ = [
"XLMRobertaCLIP",
......
import torch
from diffusers.models.embeddings import get_timestep_embedding, get_3d_sincos_pos_embed
from diffusers.models.embeddings import get_3d_sincos_pos_embed, get_timestep_embedding
class CogvideoxPreInfer:
......
import torch
from safetensors import safe_open
import os
import glob
import math
import json
import math
import os
from lightx2v.models.networks.cogvideox.weights.pre_weights import CogvideoxPreWeights
from lightx2v.models.networks.cogvideox.weights.post_weights import CogvideoxPostWeights
from lightx2v.models.networks.cogvideox.weights.transformers_weights import CogvideoxTransformerWeights
import torch
from safetensors import safe_open
from lightx2v.models.networks.cogvideox.infer.post_infer import CogvideoxPostInfer
from lightx2v.models.networks.cogvideox.infer.pre_infer import CogvideoxPreInfer
from lightx2v.models.networks.cogvideox.infer.transformer_infer import CogvideoxTransformerInfer
from lightx2v.models.networks.cogvideox.infer.post_infer import CogvideoxPostInfer
from lightx2v.models.networks.cogvideox.weights.post_weights import CogvideoxPostWeights
from lightx2v.models.networks.cogvideox.weights.pre_weights import CogvideoxPreWeights
from lightx2v.models.networks.cogvideox.weights.transformers_weights import CogvideoxTransformerWeights
class CogvideoxModel:
......
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER
class CogvideoxPostWeights:
......
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
class CogvideoxPreWeights:
......
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER
class CogvideoxTransformerWeights:
......
from ..transformer_infer import HunyuanTransformerInfer
from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer
import torch
import numpy as np
import torch
from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer
from ..transformer_infer import HunyuanTransformerInfer
class HunyuanTransformerInferTeaCaching(HunyuanTransformerInfer):
......
from typing import Dict
import math
from typing import Dict
import torch
......
import torch
import math
import torch
from einops import rearrange
......
import torch
from einops import rearrange
from .utils_bf16 import apply_rotary_emb
from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *
from .utils_bf16 import apply_rotary_emb
class HunyuanTransformerInfer(BaseTransformerInfer):
def __init__(self, config):
......
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from typing import Any, List, Tuple, Optional, Union, Dict
def rms_norm(x, weight, eps):
......
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from typing import Any, List, Tuple, Optional, Union, Dict
def rms_norm(x, weight, eps):
......
import json
import os
import torch
import json
from lightx2v.models.networks.hunyuan.weights.pre_weights import HunyuanPreWeights
from lightx2v.models.networks.hunyuan.weights.post_weights import HunyuanPostWeights
from lightx2v.models.networks.hunyuan.weights.transformer_weights import HunyuanTransformerWeights
from lightx2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from loguru import logger
from safetensors import safe_open
from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import (
HunyuanTransformerInferTaylorCaching,
HunyuanTransformerInferTeaCaching,
HunyuanTransformerInferAdaCaching,
HunyuanTransformerInferCustomCaching,
HunyuanTransformerInferTaylorCaching,
HunyuanTransformerInferTeaCaching,
)
from lightx2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.models.networks.hunyuan.weights.post_weights import HunyuanPostWeights
from lightx2v.models.networks.hunyuan.weights.pre_weights import HunyuanPreWeights
from lightx2v.models.networks.hunyuan.weights.transformer_weights import HunyuanTransformerWeights
from lightx2v.utils.envs import *
from loguru import logger
from safetensors import safe_open
class HunyuanModel:
......
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
class HunyuanPostWeights(WeightModule):
......
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER, ATTN_WEIGHT_REGISTER
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER
class HunyuanPreWeights(WeightModule):
......
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER, ATTN_WEIGHT_REGISTER
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
class HunyuanTransformerWeights(WeightModule):
......
......@@ -2,9 +2,10 @@ try:
import flash_attn
except ModuleNotFoundError:
flash_attn = None
import math
import os
import safetensors
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
......
import glob
import os
import torch
import time
import glob
import torch
from safetensors import safe_open
from lightx2v.common.ops.attn.radial_attn import MaskMap
from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
from safetensors import safe_open
from lightx2v.common.ops.attn.radial_attn import MaskMap
class WanAudioModel(WanModel):
......
import os
import torch
from safetensors import safe_open
from lightx2v.common.ops.attn.radial_attn import MaskMap
from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
WanTransformerInferCausVid,
)
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
WanTransformerInferCausVid,
)
from lightx2v.utils.envs import *
from safetensors import safe_open
class WanCausVidModel(WanModel):
......
import glob
import json
import os
import sys
import torch
import glob
import json
from loguru import logger
from safetensors import safe_open
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from lightx2v.utils.envs import *
from loguru import logger
class WanDistillModel(WanModel):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment