"vscode:/vscode.git/clone" did not exist on "a9b0d50bafa26c4f2a9e665ff7b805face3f24e6"
Commit f6d4fc85 authored by PengGao, committed by GitHub

style: add ruff isort (#183)

parent 878f5a48
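
This commit enables ruff's isort-compatible import sorting and reflows every file's imports into the standard section order: standard library, third-party, first-party (lightx2v), then local relative imports, with straight `import` statements ahead of `from` imports and the names inside each `from` import alphabetized. A minimal pyproject.toml sketch of such a configuration (the repository's actual config is not shown on this page; the `known-first-party` entry is an assumption inferred from how the hunks below group `lightx2v` imports):

    [tool.ruff.lint]
    extend-select = ["I"]  # enable isort rules (I001: unsorted-imports)

    [tool.ruff.lint.isort]
    known-first-party = ["lightx2v"]  # assumed: sorts lightx2v after third-party imports

Running `ruff check --fix .` then rewrites the imports, producing the per-file hunks below.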
 import torch
-from .template import AttnWeightTemplate
-from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
 from loguru import logger
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+from .template import AttnWeightTemplate

 if torch.cuda.get_device_capability(0) == (8, 9):
     try:
         from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
......
 import torch
-from .template import AttnWeightTemplate
-from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
-from loguru import logger
 import torch.nn as nn
+from loguru import logger
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+from .template import AttnWeightTemplate

 try:
     from spas_sage_attn.autotune import SparseAttentionMeansim
......
 import torch
 import torch.nn.functional as F
-from .template import AttnWeightTemplate
 from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+from .template import AttnWeightTemplate

 @ATTN_WEIGHT_REGISTER("torch_sdpa")
 class TorchSDPAWeight(AttnWeightTemplate):
......
 import torch
-from .template import AttnWeightTemplate
-from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
 import torch.distributed as dist
-from .utils.all2all import all2all_seq2head, all2all_head2seq
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+from .template import AttnWeightTemplate
+from .utils.all2all import all2all_head2seq, all2all_seq2head

 @ATTN_WEIGHT_REGISTER("ulysses")
......
 from typing import Optional
-from loguru import logger
 import torch
 import torch.distributed as dist
+from loguru import logger

 class RingComm:
......
-import torch
 from abc import ABCMeta, abstractmethod
+import torch
 from lightx2v.utils.registry_factory import CONV2D_WEIGHT_REGISTER
......
-import torch
 from abc import ABCMeta, abstractmethod
+import torch
 from lightx2v.utils.registry_factory import CONV3D_WEIGHT_REGISTER
......
-import torch
 from abc import ABCMeta, abstractmethod
-from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
-from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
-from lightx2v.utils.envs import *
+import torch
 from loguru import logger
+from lightx2v.utils.envs import *
+from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER

 try:
     from vllm import _custom_ops as ops
 except ImportError:
......
 import torch
-from .mm_weight import MMWeight
+from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
-from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
+from .mm_weight import MMWeight

 @MM_WEIGHT_REGISTER("Calib")
......
-from .rms_norm_weight import *
 from .layer_norm_weight import *
+from .rms_norm_weight import *
-import torch
 from abc import ABCMeta, abstractmethod
-from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
+import torch
 from lightx2v.utils.envs import *
+from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER

 class LNWeightTemplate(metaclass=ABCMeta):
......
-import torch
 from abc import ABCMeta, abstractmethod
-from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER
+import torch
 from lightx2v.utils.envs import *
+from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER

 try:
     import sgl_kernel
......
 import torch
-from lightx2v.utils.registry_factory import TENSOR_REGISTER
 from lightx2v.utils.envs import *
+from lightx2v.utils.registry_factory import TENSOR_REGISTER

 @TENSOR_REGISTER("Default")
......
+import math
 from abc import ABC, abstractmethod
 import torch
-import math

 class BaseTransformerInfer(ABC):
......
 import argparse
-import torch
-import torch.distributed as dist
 import json
-from lightx2v.utils.envs import *
-from lightx2v.utils.utils import seed_all
-from lightx2v.utils.profiler import ProfilingContext
-from lightx2v.utils.set_config import set_config, print_config
-from lightx2v.utils.registry_factory import RUNNER_REGISTER
+import torch
+import torch.distributed as dist
+from loguru import logger
+from lightx2v.common.ops import *
+from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner
+from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
-from lightx2v.models.runners.wan.wan_runner import WanRunner, Wan22MoeRunner
-from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
+from lightx2v.models.runners.wan.wan_audio_runner import Wan22MoeAudioRunner, WanAudioRunner
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
-from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner, Wan22MoeAudioRunner
+from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
+from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner
 from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
-from lightx2v.models.runners.graph_runner import GraphRunner
-from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner
-from lightx2v.common.ops import *
-from loguru import logger
+from lightx2v.utils.envs import *
+from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.registry_factory import RUNNER_REGISTER
+from lightx2v.utils.set_config import print_config, set_config
+from lightx2v.utils.utils import seed_all

 def init_runner(config):
......
 import torch
-from transformers import CLIPTextModel, AutoTokenizer
 from loguru import logger
+from transformers import AutoTokenizer, CLIPTextModel

 class TextEncoderHFClipModel:
......
 import torch
-from transformers import AutoModel, AutoTokenizer
 from loguru import logger
+from transformers import AutoModel, AutoTokenizer

 class TextEncoderHFLlamaModel:
......
-import torch
-from PIL import Image
 import numpy as np
+import torch
 import torchvision.transforms as transforms
-from transformers import LlavaForConditionalGeneration, CLIPImageProcessor, AutoTokenizer
+from PIL import Image
 from loguru import logger
+from transformers import AutoTokenizer, CLIPImageProcessor, LlavaForConditionalGeneration

 def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
......
@@ -3,14 +3,15 @@
 import logging
 import math
 import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .tokenizer import HuggingfaceTokenizer
 from loguru import logger
-from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8, Q8FQuantLinearInt8, Q8FQuantLinearFp8
+from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
+from .tokenizer import HuggingfaceTokenizer

 __all__ = [
     "T5Model",
......
-import torch
 import os
+import torch
 from transformers import T5EncoderModel, T5Tokenizer
......