Commit f6d4fc85 authored by PengGao's avatar PengGao Committed by GitHub
Browse files

style: add ruff isort (#183)

parent 878f5a48
...@@ -7,11 +7,11 @@ import torch ...@@ -7,11 +7,11 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision.transforms as T import torchvision.transforms as T
from loguru import logger
# from lightx2v.attentions import attention # from lightx2v.attentions import attention
from lightx2v.common.ops.attn import TorchSDPAWeight from lightx2v.common.ops.attn import TorchSDPAWeight
from loguru import logger from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8, Q8FQuantLinearInt8, Q8FQuantLinearFp8
__all__ = [ __all__ = [
"XLMRobertaCLIP", "XLMRobertaCLIP",
......
import torch import torch
from diffusers.models.embeddings import get_timestep_embedding, get_3d_sincos_pos_embed from diffusers.models.embeddings import get_3d_sincos_pos_embed, get_timestep_embedding
class CogvideoxPreInfer: class CogvideoxPreInfer:
......
import torch
from safetensors import safe_open
import os
import glob import glob
import math
import json import json
import math
import os
from lightx2v.models.networks.cogvideox.weights.pre_weights import CogvideoxPreWeights import torch
from lightx2v.models.networks.cogvideox.weights.post_weights import CogvideoxPostWeights from safetensors import safe_open
from lightx2v.models.networks.cogvideox.weights.transformers_weights import CogvideoxTransformerWeights
from lightx2v.models.networks.cogvideox.infer.post_infer import CogvideoxPostInfer
from lightx2v.models.networks.cogvideox.infer.pre_infer import CogvideoxPreInfer from lightx2v.models.networks.cogvideox.infer.pre_infer import CogvideoxPreInfer
from lightx2v.models.networks.cogvideox.infer.transformer_infer import CogvideoxTransformerInfer from lightx2v.models.networks.cogvideox.infer.transformer_infer import CogvideoxTransformerInfer
from lightx2v.models.networks.cogvideox.infer.post_infer import CogvideoxPostInfer from lightx2v.models.networks.cogvideox.weights.post_weights import CogvideoxPostWeights
from lightx2v.models.networks.cogvideox.weights.pre_weights import CogvideoxPreWeights
from lightx2v.models.networks.cogvideox.weights.transformers_weights import CogvideoxTransformerWeights
class CogvideoxModel: class CogvideoxModel:
......
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER
class CogvideoxPostWeights: class CogvideoxPostWeights:
......
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
class CogvideoxPreWeights: class CogvideoxPreWeights:
......
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER
class CogvideoxTransformerWeights: class CogvideoxTransformerWeights:
......
from ..transformer_infer import HunyuanTransformerInfer
from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer
import torch
import numpy as np import numpy as np
import torch
from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer
from ..transformer_infer import HunyuanTransformerInfer
class HunyuanTransformerInferTeaCaching(HunyuanTransformerInfer): class HunyuanTransformerInferTeaCaching(HunyuanTransformerInfer):
......
from typing import Dict
import math import math
from typing import Dict
import torch import torch
......
import torch
import math import math
import torch
from einops import rearrange from einops import rearrange
......
import torch import torch
from einops import rearrange from einops import rearrange
from .utils_bf16 import apply_rotary_emb
from lightx2v.common.offload.manager import WeightAsyncStreamManager from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from .utils_bf16 import apply_rotary_emb
class HunyuanTransformerInfer(BaseTransformerInfer): class HunyuanTransformerInfer(BaseTransformerInfer):
def __init__(self, config): def __init__(self, config):
......
from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
from typing import Any, List, Tuple, Optional, Union, Dict
def rms_norm(x, weight, eps): def rms_norm(x, weight, eps):
......
from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
from typing import Any, List, Tuple, Optional, Union, Dict
def rms_norm(x, weight, eps): def rms_norm(x, weight, eps):
......
import json
import os import os
import torch import torch
import json from loguru import logger
from lightx2v.models.networks.hunyuan.weights.pre_weights import HunyuanPreWeights from safetensors import safe_open
from lightx2v.models.networks.hunyuan.weights.post_weights import HunyuanPostWeights
from lightx2v.models.networks.hunyuan.weights.transformer_weights import HunyuanTransformerWeights
from lightx2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import ( from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import (
HunyuanTransformerInferTaylorCaching,
HunyuanTransformerInferTeaCaching,
HunyuanTransformerInferAdaCaching, HunyuanTransformerInferAdaCaching,
HunyuanTransformerInferCustomCaching, HunyuanTransformerInferCustomCaching,
HunyuanTransformerInferTaylorCaching,
HunyuanTransformerInferTeaCaching,
) )
from lightx2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.models.networks.hunyuan.weights.post_weights import HunyuanPostWeights
from lightx2v.models.networks.hunyuan.weights.pre_weights import HunyuanPreWeights
from lightx2v.models.networks.hunyuan.weights.transformer_weights import HunyuanTransformerWeights
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from loguru import logger
from safetensors import safe_open
class HunyuanModel: class HunyuanModel:
......
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.modules.weight_module import WeightModule from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
class HunyuanPostWeights(WeightModule): class HunyuanPostWeights(WeightModule):
......
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER, ATTN_WEIGHT_REGISTER
from lightx2v.common.modules.weight_module import WeightModule from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER
class HunyuanPreWeights(WeightModule): class HunyuanPreWeights(WeightModule):
......
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER, ATTN_WEIGHT_REGISTER
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
class HunyuanTransformerWeights(WeightModule): class HunyuanTransformerWeights(WeightModule):
......
...@@ -2,9 +2,10 @@ try: ...@@ -2,9 +2,10 @@ try:
import flash_attn import flash_attn
except ModuleNotFoundError: except ModuleNotFoundError:
flash_attn = None flash_attn = None
import math
import os import os
import safetensors import safetensors
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
......
import glob
import os import os
import torch
import time import time
import glob
import torch
from safetensors import safe_open
from lightx2v.common.ops.attn.radial_attn import MaskMap
from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
from lightx2v.models.networks.wan.model import WanModel from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import ( from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights, WanTransformerWeights,
) )
from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
from safetensors import safe_open
from lightx2v.common.ops.attn.radial_attn import MaskMap
class WanAudioModel(WanModel): class WanAudioModel(WanModel):
......
import os import os
import torch import torch
from safetensors import safe_open
from lightx2v.common.ops.attn.radial_attn import MaskMap from lightx2v.common.ops.attn.radial_attn import MaskMap
from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
WanTransformerInferCausVid,
)
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.model import WanModel from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import ( from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights, WanTransformerWeights,
) )
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
WanTransformerInferCausVid,
)
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from safetensors import safe_open
class WanCausVidModel(WanModel): class WanCausVidModel(WanModel):
......
import glob
import json
import os import os
import sys import sys
import torch import torch
import glob from loguru import logger
import json
from safetensors import safe_open from safetensors import safe_open
from lightx2v.models.networks.wan.model import WanModel from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import ( from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights, WanTransformerWeights,
) )
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from loguru import logger
class WanDistillModel(WanModel): class WanDistillModel(WanModel):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment