Unverified commit b161b913, authored by Gu Shiqiao, committed by GitHub

fix import error (#573)

parent 0e08595c
@@ -292,4 +292,3 @@ pipe.generate(
<div align="center">
Built with ❤️ by the LightX2V team
</div>
@@ -3,7 +3,12 @@ import math
import os
import torch
from transformers import Qwen2Tokenizer, Qwen2_5_VLForConditionalGeneration
try:
    from transformers import Qwen2Tokenizer, Qwen2_5_VLForConditionalGeneration
except ImportError:
    Qwen2Tokenizer = None
    Qwen2_5_VLForConditionalGeneration = None
from lightx2v_platform.base.global_var import AI_DEVICE
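The hunk above makes the Qwen2.5-VL classes optional: on a transformers release that does not ship them, the import fails and both names fall back to None. A minimal sketch of how code in the same module could guard on that sentinel follows; the helper name and error message are illustrative, not part of the repository:

```python
# Illustrative sketch only: fail loudly at load time if the optional
# transformers classes were not importable (they are None in that case).
def load_qwen2_5_vl(model_path):
    if Qwen2Tokenizer is None or Qwen2_5_VLForConditionalGeneration is None:
        raise ImportError(
            "Qwen2.5-VL support requires a transformers release that provides "
            "Qwen2_5_VLForConditionalGeneration; please upgrade transformers."
        )
    tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path)
    return tokenizer, model
```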
@@ -24,6 +24,7 @@ from lightx2v.models.input_encoders.hf.q_linear import ( # noqa E402
    SglQuantLinearFp8, # noqa E402
    TorchaoQuantLinearInt8, # noqa E402
    VllmQuantLinearInt8, # noqa E402,
    VllmQuantLinearFp8, # noqa E402
)
from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8 # noqa E402
from lightx2v.models.input_encoders.hf.wan.t5.tokenizer import HuggingfaceTokenizer # noqa E402
@@ -195,6 +196,8 @@ class T5Attention(nn.Module):
            linear_cls = VllmQuantLinearInt8
        elif quant_scheme in ["fp8", "fp8-sgl"]:
            linear_cls = SglQuantLinearFp8
        elif quant_scheme == "fp8-vllm":
            linear_cls = VllmQuantLinearFp8
        elif quant_scheme == "int8-torchao":
            linear_cls = TorchaoQuantLinearInt8
        elif quant_scheme == "int8-q8f":
@@ -268,6 +271,8 @@ class T5FeedForward(nn.Module):
            linear_cls = VllmQuantLinearInt8
        elif quant_scheme in ["fp8", "fp8-sgl"]:
            linear_cls = SglQuantLinearFp8
        elif quant_scheme == "fp8-vllm":
            linear_cls = VllmQuantLinearFp8
        elif quant_scheme == "int8-torchao":
            linear_cls = TorchaoQuantLinearInt8
        elif quant_scheme == "int8-q8f":
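Both T5 hunks extend the same quant-scheme dispatch with an "fp8-vllm" branch that selects VllmQuantLinearFp8. The selection can be read as a small lookup table; the helper below is a simplified, illustrative restatement rather than a LightX2V API, and it lists only the mappings visible in the hunks (the int8 branches are partly out of view):

```python
# Illustrative restatement of the elif chains above; not a LightX2V API.
QUANT_LINEAR_BY_SCHEME = {
    "fp8": SglQuantLinearFp8,
    "fp8-sgl": SglQuantLinearFp8,
    "fp8-vllm": VllmQuantLinearFp8,       # the branch this commit adds
    "int8-torchao": TorchaoQuantLinearInt8,
}

def pick_quant_linear(quant_scheme):
    try:
        return QUANT_LINEAR_BY_SCHEME[quant_scheme]
    except KeyError:
        raise ValueError(f"unsupported quant_scheme: {quant_scheme!r}")
```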
@@ -10,7 +10,7 @@ from loguru import logger
# from lightx2v.attentions import attention
from lightx2v.common.ops.attn import TorchSDPAWeight
from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearInt8
from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
from lightx2v.utils.utils import load_weights
from lightx2v_platform.base.global_var import AI_DEVICE
from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8
@@ -65,6 +65,8 @@ class SelfAttention(nn.Module):
            linear_cls = VllmQuantLinearInt8
        elif quant_scheme in ["fp8", "fp8-sgl"]:
            linear_cls = SglQuantLinearFp8
        elif quant_scheme == "fp8-vllm":
            linear_cls = VllmQuantLinearFp8
        elif quant_scheme == "int8-torchao":
            linear_cls = TorchaoQuantLinearInt8
        elif quant_scheme == "int8-q8f":
@@ -147,6 +149,8 @@ class AttentionBlock(nn.Module):
            linear_cls = VllmQuantLinearInt8
        elif quant_scheme in ["fp8", "fp8-sgl"]:
            linear_cls = SglQuantLinearFp8
        elif quant_scheme == "fp8-vllm":
            linear_cls = VllmQuantLinearFp8
        elif quant_scheme == "int8-torchao":
            linear_cls = TorchaoQuantLinearInt8
        elif quant_scheme == "int8-q8f":
@@ -8,9 +8,13 @@ try:
    FLASH_ATTN_3_AVAILABLE = True
except ImportError:
    from flash_attn import flash_attn_func
    try:
        from flash_attn import flash_attn_func
        FLASH_ATTN_3_AVAILABLE = False
    FLASH_ATTN_3_AVAILABLE = False
    except ImportError:
        FLASH_ATTN_3_AVAILABLE = False
from lightx2v.models.networks.wan.infer.matrix_game2.posemb_layers import apply_rotary_emb, get_nd_rotary_pos_embed
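Because the removed and added lines of this hunk appear interleaved, here is a hedged reconstruction of the import block after the change. The import inside the outer try is not visible in the hunk, so flash_attn_interface is an assumption based on the usual FlashAttention-3 package layout:

```python
try:
    # Not shown in the hunk; assumed FlashAttention-3 import.
    from flash_attn_interface import flash_attn_func
    FLASH_ATTN_3_AVAILABLE = True
except ImportError:
    try:
        # Fall back to the FlashAttention-2 package.
        from flash_attn import flash_attn_func
        FLASH_ATTN_3_AVAILABLE = False
    except ImportError:
        # Neither package is installed; the flag stays False and
        # flash_attn_func remains undefined in this module.
        FLASH_ATTN_3_AVAILABLE = False
```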
@@ -113,12 +113,14 @@ class LightX2VPipeline:
        boundary_step_index=2,
        denoising_step_list=[1000, 750, 500, 250],
        config_json=None,
        rope_type="torch",
    ):
        if config_json is not None:
            self.set_infer_config_json(config_json)
        else:
            self.set_infer_config(
                attn_mode,
                rope_type,
                infer_steps,
                num_frames,
                height,
@@ -142,6 +144,7 @@
    def set_infer_config(
        self,
        attn_mode,
        rope_type,
        infer_steps,
        num_frames,
        height,
@@ -164,7 +167,7 @@
            self.enable_cfg = False
        else:
            self.enable_cfg = True
        self.rope_type = rope_type
        self.fps = fps
        self.aspect_ratio = aspect_ratio
        self.boundary = boundary
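The pipeline hunks thread a new rope_type keyword (default "torch") from the public configuration call through set_infer_config and store it on self.rope_type; a supplied config_json still takes precedence and bypasses the keyword. A hedged usage sketch follows; the hunk does not show which LightX2VPipeline method gained the keyword, so create_pipeline is a placeholder name and all argument values are examples:

```python
# Hedged usage sketch; method name and argument values are illustrative.
pipe = LightX2VPipeline()          # real construction likely takes model paths, etc.
pipe.create_pipeline(              # placeholder for the method patched above
    attn_mode="torch_sdpa",        # example value, not from the diff
    rope_type="torch",             # new keyword; "torch" is its default
    infer_steps=4,                 # remaining values are examples only
    num_frames=81,
    height=480,
    denoising_step_list=[1000, 750, 500, 250],
)
```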