Commit f6d4fc85 authored by PengGao's avatar PengGao Committed by GitHub
Browse files

style: add ruff isort (#183)

parent 878f5a48
import os
import gradio as gr
import argparse import argparse
import json
import torch
import gc import gc
from easydict import EasyDict import glob
import importlib.util
import json
import os
import random
from datetime import datetime from datetime import datetime
from loguru import logger
import importlib.util import gradio as gr
import psutil import psutil
import random import torch
import glob from easydict import EasyDict
from loguru import logger
logger.add( logger.add(
"inference_logs.log", "inference_logs.log",
......
import os
import gradio as gr
import argparse import argparse
import json
import torch
import gc import gc
from easydict import EasyDict import glob
import importlib.util
import json
import os
import random
from datetime import datetime from datetime import datetime
from loguru import logger
import importlib.util import gradio as gr
import psutil import psutil
import random import torch
import glob from easydict import EasyDict
from loguru import logger
logger.add( logger.add(
"inference_logs.log", "inference_logs.log",
......
...@@ -15,9 +15,8 @@ import os ...@@ -15,9 +15,8 @@ import os
import sys import sys
from typing import List from typing import List
from sphinx.ext import autodoc
import sphinxcontrib.redoc import sphinxcontrib.redoc
from sphinx.ext import autodoc
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
sys.path.append(os.path.abspath("../..")) sys.path.append(os.path.abspath("../.."))
......
...@@ -15,9 +15,8 @@ import os ...@@ -15,9 +15,8 @@ import os
import sys import sys
from typing import List from typing import List
from sphinx.ext import autodoc
import sphinxcontrib.redoc import sphinxcontrib.redoc
from sphinx.ext import autodoc
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
sys.path.append(os.path.abspath("../..")) sys.path.append(os.path.abspath("../.."))
......
...@@ -15,9 +15,8 @@ import os ...@@ -15,9 +15,8 @@ import os
import sys import sys
from typing import List from typing import List
from sphinx.ext import autodoc
import sphinxcontrib.redoc import sphinxcontrib.redoc
from sphinx.ext import autodoc
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
sys.path.append(os.path.abspath("../..")) sys.path.append(os.path.abspath("../.."))
......
import argparse import argparse
import concurrent.futures
import os
import socket
import subprocess import subprocess
import time import time
import socket from dataclasses import dataclass
import os from typing import Dict, List, Optional
from typing import List, Optional, Dict
import psutil import psutil
import requests import requests
from loguru import logger from loguru import logger
import concurrent.futures
from dataclasses import dataclass
@dataclass @dataclass
......
import argparse import argparse
import sys
import signal
import atexit import atexit
import signal
import sys
from pathlib import Path from pathlib import Path
from loguru import logger
import uvicorn import uvicorn
from loguru import logger
from lightx2v.server.api import ApiServer from lightx2v.server.api import ApiServer
from lightx2v.server.service import DistributedInferenceService from lightx2v.server.service import DistributedInferenceService
......
import argparse import argparse
from typing import Optional
from fastapi import FastAPI
from pydantic import BaseModel
from loguru import logger
import uvicorn
import json import json
import os import os
from typing import Optional
import torch import torch
from lightx2v.common.ops import * import uvicorn
from fastapi import FastAPI
from loguru import logger
from pydantic import BaseModel
from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.common.ops import *
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
from lightx2v.utils.profiler import ProfilingContext from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.service_utils import BaseServiceStatus, ImageTransporter, ProcessManager, TaskStatusMessage, TensorTransporter
from lightx2v.utils.set_config import set_config from lightx2v.utils.set_config import set_config
from lightx2v.utils.service_utils import TaskStatusMessage, BaseServiceStatus, ProcessManager, TensorTransporter, ImageTransporter
tensor_transporter = TensorTransporter() tensor_transporter = TensorTransporter()
image_transporter = ImageTransporter() image_transporter = ImageTransporter()
......
import argparse import argparse
from fastapi import FastAPI
from pydantic import BaseModel
from loguru import logger
import uvicorn
import json import json
import os import os
import torch import torch
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
import uvicorn
from fastapi import FastAPI
from loguru import logger
from pydantic import BaseModel
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
from lightx2v.utils.profiler import ProfilingContext from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.service_utils import BaseServiceStatus, ImageTransporter, ProcessManager, TaskStatusMessage, TensorTransporter
from lightx2v.utils.set_config import set_config from lightx2v.utils.set_config import set_config
from lightx2v.utils.service_utils import TaskStatusMessage, BaseServiceStatus, ProcessManager, TensorTransporter, ImageTransporter
tensor_transporter = TensorTransporter() tensor_transporter = TensorTransporter()
image_transporter = ImageTransporter() image_transporter = ImageTransporter()
......
import argparse import argparse
from fastapi import FastAPI
from pydantic import BaseModel
from loguru import logger
import uvicorn
import json import json
from typing import Optional from typing import Optional
import uvicorn
from fastapi import FastAPI
from loguru import logger
from pydantic import BaseModel
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from lightx2v.utils.profiler import ProfilingContext from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.service_utils import TaskStatusMessage, BaseServiceStatus, ProcessManager from lightx2v.utils.service_utils import BaseServiceStatus, ProcessManager, TaskStatusMessage
# ========================= # =========================
# FastAPI Related Code # FastAPI Related Code
......
import argparse import argparse
from typing import Optional
from fastapi import FastAPI
from pydantic import BaseModel
from loguru import logger
import uvicorn
import json import json
import os import os
from typing import Optional
import torch import torch
import uvicorn
from fastapi import FastAPI
from loguru import logger
from pydantic import BaseModel
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
from lightx2v.utils.profiler import ProfilingContext from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.service_utils import BaseServiceStatus, ImageTransporter, ProcessManager, TaskStatusMessage, TensorTransporter
from lightx2v.utils.set_config import set_config from lightx2v.utils.set_config import set_config
from lightx2v.utils.service_utils import TaskStatusMessage, BaseServiceStatus, ProcessManager, TensorTransporter, ImageTransporter
tensor_transporter = TensorTransporter() tensor_transporter = TensorTransporter()
image_transporter = ImageTransporter() image_transporter = ImageTransporter()
......
import argparse import argparse
from fastapi import FastAPI
from pydantic import BaseModel
from loguru import logger
from typing import Optional
import numpy as np
import uvicorn
import json import json
import os import os
from typing import Optional
import numpy as np
import torch import torch
import torchvision import torchvision
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
from lightx2v.common.ops import * import uvicorn
from fastapi import FastAPI
from loguru import logger
from pydantic import BaseModel
from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.common.ops import *
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
from lightx2v.utils.profiler import ProfilingContext from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.service_utils import BaseServiceStatus, ImageTransporter, ProcessManager, TaskStatusMessage, TensorTransporter
from lightx2v.utils.set_config import set_config from lightx2v.utils.set_config import set_config
from lightx2v.utils.service_utils import TaskStatusMessage, BaseServiceStatus, ProcessManager, TensorTransporter, ImageTransporter
tensor_transporter = TensorTransporter() tensor_transporter = TensorTransporter()
image_transporter = ImageTransporter() image_transporter = ImageTransporter()
......
...@@ -19,9 +19,10 @@ import argparse ...@@ -19,9 +19,10 @@ import argparse
import os import os
import tensorrt as trt import tensorrt as trt
from .common_runtime import *
from loguru import logger from loguru import logger
from .common_runtime import *
try: try:
# Sometimes python does not understand FileNotFoundError # Sometimes python does not understand FileNotFoundError
FileNotFoundError FileNotFoundError
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# #
import ctypes import ctypes
from typing import Optional, List, Union from typing import List, Optional, Union
import numpy as np import numpy as np
import tensorrt as trt import tensorrt as trt
......
import torch import gc
import threading
import queue import queue
import threading
import time import time
import gc
from loguru import logger
from collections import OrderedDict from collections import OrderedDict
import torch
from loguru import logger
class WeightAsyncStreamManager(object): class WeightAsyncStreamManager(object):
def __init__(self, blocks_num, offload_ratio=1, phases_num=1): def __init__(self, blocks_num, offload_ratio=1, phases_num=1):
......
from .attn import *
from .conv import *
from .mm import * from .mm import *
from .norm import * from .norm import *
from .conv import *
from .tensor import * from .tensor import *
from .attn import *
...@@ -2,6 +2,6 @@ from .flash_attn import * ...@@ -2,6 +2,6 @@ from .flash_attn import *
from .radial_attn import * from .radial_attn import *
from .ring_attn import * from .ring_attn import *
from .sage_attn import * from .sage_attn import *
from .sparge_attn import *
from .torch_sdpa import * from .torch_sdpa import *
from .ulysses_attn import * from .ulysses_attn import *
from .sparge_attn import *
...@@ -13,9 +13,10 @@ except ImportError: ...@@ -13,9 +13,10 @@ except ImportError:
logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first") logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
flash_attn_varlen_func_v3 = None flash_attn_varlen_func_v3 = None
from .template import AttnWeightTemplate
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
@ATTN_WEIGHT_REGISTER("flash_attn2") @ATTN_WEIGHT_REGISTER("flash_attn2")
class FlashAttn2Weight(AttnWeightTemplate): class FlashAttn2Weight(AttnWeightTemplate):
......
import torch import torch
from .template import AttnWeightTemplate
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
try: try:
import flashinfer import flashinfer
from packaging import version from packaging import version
......
import torch import torch
from .template import AttnWeightTemplate
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
import torch.distributed as dist import torch.distributed as dist
from .utils.ring_comm import RingComm
import torch.nn.functional as F import torch.nn.functional as F
from loguru import logger from loguru import logger
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
from .utils.ring_comm import RingComm
try: try:
import flash_attn import flash_attn
from flash_attn.flash_attn_interface import flash_attn_varlen_func from flash_attn.flash_attn_interface import flash_attn_varlen_func
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment