Commit cc2a283a (xuwx1/LightX2V)

Add torchao kernel

Dev quant

Authored Jul 21, 2025 by gushiqiao; committed by GitHub on Jul 21, 2025.
Parents: 29a90944, adf8df9d
Changes: 8 changed files with 140 additions and 34 deletions (+140 −34):
- configs/quantization/wan_i2v_torchao.json (+21 −0)
- docs/EN/source/deploy_guides/deploy_local_windows.md (+0 −16)
- lightx2v/common/ops/mm/mm_weight.py (+35 −0)
- lightx2v/models/input_encoders/hf/q_linear.py (+54 −3)
- lightx2v/models/input_encoders/hf/t5/model.py (+9 −5)
- lightx2v/models/input_encoders/hf/xlm_roberta/model.py (+9 −5)
- lightx2v/models/networks/wan/distill_model.py (+4 −0)
- lightx2v/models/runners/wan/wan_runner.py (+8 −5)
**configs/quantization/wan_i2v_torchao.json** (new file, mode 100755)
```json
{
    "infer_steps": 4,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "sage_attn2",
    "cross_attn_1_type": "sage_attn2",
    "cross_attn_2_type": "sage_attn2",
    "seed": 42,
    "sample_guide_scale": 5,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false,
    "mm_config": {
        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao"
    },
    "t5_quantized": true,
    "t5_quant_scheme": "int8-torchao",
    "clip_quantized": true,
    "clip_quant_scheme": "int8-torchao"
}
```
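The quantization keys are what wire this config to the code changes below: `mm_config.mm_type` is the registry name the new kernel class registers under in `mm_weight.py`, and `"int8-torchao"` routes the T5 and CLIP encoders to the new `TorchaoQuantLinearInt8` layer. A minimal sketch of reading those keys (the file path is the one added by this commit; the loading code itself is illustrative, not the project's CLI):

```python
import json

# Load the config added by this commit.
with open("configs/quantization/wan_i2v_torchao.json") as f:
    config = json.load(f)

# Registry key consumed by MM_WEIGHT_REGISTER in lightx2v/common/ops/mm/mm_weight.py.
print(config["mm_config"]["mm_type"])
# -> W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao

# Scheme strings that make the T5/CLIP encoders pick TorchaoQuantLinearInt8.
print(config["t5_quant_scheme"], config["clip_quant_scheme"])
# -> int8-torchao int8-torchao
```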
**docs/EN/source/deploy_guides/deploy_local_windows.md**
````diff
@@ -19,22 +19,6 @@ This document provides detailed instructions for deploying LightX2V locally on W
 - **CUDA**: 12.4 or higher version
 - **Dependencies**: Refer to LightX2V project's requirements_win.txt
-
-### Installation Steps
-
-1. **Clone Project**
-   ```cmd
-   git clone https://github.com/ModelTC/LightX2V.git
-   cd LightX2V
-   ```
-2. **Install Dependencies**
-   ```cmd
-   pip install -r requirements_win.txt
-   ```
-3. **Download Models**
-   Refer to [Model Download Guide](../getting_started/quickstart.md) to download required models
-
 ## 🎯 Usage Methods
 
 ### Method 1: Using Batch File Inference
````
**lightx2v/common/ops/mm/mm_weight.py**
```diff
@@ -25,6 +25,11 @@ try:
 except ImportError:
     deep_gemm = None
 
+try:
+    from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
+except ModuleNotFoundError:
+    quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
+
 
 class MMWeightTemplate(metaclass=ABCMeta):
     def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
@@ -232,6 +237,9 @@ class MMWeightQuantTemplate(MMWeightTemplate):
     # =========================
     # act quant kernels
     # =========================
+    def act_quant_int8_perchannel_sym_torchao(self, x):
+        input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
+        return input_tensor_quant, input_tensor_scale
+
     def act_quant_fp8_perchannel_sym_vllm(self, x):
         input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
@@ -624,6 +632,33 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
         return output_tensor
 
 
+@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao")
+class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
+    """
+    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao
+
+    Quant MM:
+        Weight: int8 perchannel sym
+        Act: int8 perchannel dynamic sym
+        Kernel: Torchao
+    """
+
+    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
+        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
+        self.load_func = self.load_int8_perchannel_sym
+        self.weight_need_transpose = True
+        self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao
+
+    def apply(self, input_tensor):
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=torch.bfloat16)
+        if self.bias is not None:
+            output_tensor = output_tensor + self.bias
+        return output_tensor
+
+
 if __name__ == "__main__":
     weight_dict = {
         "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
```
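For intuition, here is a self-contained sketch of the round trip the new `apply` performs: per-channel symmetric int8 weights, per-token dynamic int8 activations, and torchao's matmul. The two torchao helpers are exactly the ones the diff imports; the shapes and the manual weight quantization are illustrative. Note that the new class reuses the `MMWeightWint8channelAint8channeldynamicSglActVllm` name of the class above it; only the registry string distinguishes the two.

```python
import torch
from torchao.quantization.utils import (
    quant_int8_per_token_matmul,
    quantize_activation_per_token_absmax,
)

# Hypothetical shapes; assumes a CUDA device for the int8 matmul.
x = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")    # activations (tokens x in)
w = torch.randn(8192, 4096, dtype=torch.bfloat16, device="cuda")  # weight (out x in)

# Weight: int8 per-channel symmetric (absmax per output channel); done offline in practice.
w_scale = w.abs().amax(dim=1, keepdim=True).float() / 127.0
w_int8 = torch.clamp(torch.round(w / w_scale), -127, 127).to(torch.int8)

# Activation: int8 per-token dynamic symmetric, as in act_quant_int8_perchannel_sym_torchao.
x_int8, x_scale = quantize_activation_per_token_absmax(x)

# torchao takes the int8 weight pre-transposed ([in, out]), which is why the new
# class sets weight_need_transpose = True.
y = quant_int8_per_token_matmul(x_int8, x_scale, w_int8.t(), w_scale.t(), output_dtype=torch.bfloat16)
print(y.shape)  # torch.Size([32, 8192])
```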
**lightx2v/models/input_encoders/hf/q_linear.py**
```diff
 import torch
 import torch.nn as nn
-from vllm import _custom_ops as ops
+
+try:
+    from vllm import _custom_ops as ops
+except ModuleNotFoundError:
+    ops = None
+
+try:
+    from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
+except ModuleNotFoundError:
+    quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
 
 
-class QuantLinearInt8(nn.Module):
+class VllmQuantLinearInt8(nn.Module):
     def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
         super().__init__()
         self.in_features = in_features
@@ -54,7 +63,7 @@ class QuantLinearInt8(nn.Module):
         return self
 
 
-class QuantLinearFp8(nn.Module):
+class VllmQuantLinearFp8(nn.Module):
     def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
         super().__init__()
         self.in_features = in_features
@@ -101,3 +110,45 @@ class QuantLinearFp8(nn.Module):
         self.weight_scale = maybe_cast(self.weight_scale)
         self.bias = maybe_cast(self.bias)
         return self
+
+
+class TorchaoQuantLinearInt8(nn.Module):
+    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
+        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
+        if bias:
+            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
+        else:
+            self.register_buffer("bias", None)
+
+    def act_quant_func(self, x):
+        input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
+        return input_tensor_quant, input_tensor_scale
+
+    def forward(self, input_tensor):
+        input_tensor = input_tensor.squeeze(0)
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight.t(), self.weight_scale.t().float(), output_dtype=torch.bfloat16)
+        if self.bias is not None:
+            output_tensor = output_tensor + self.bias
+        return output_tensor.unsqueeze(0)
+
+    def _apply(self, fn):
+        for module in self.children():
+            module._apply(fn)
+
+        def maybe_cast(t):
+            if t is not None and t.device != fn(t).device:
+                return fn(t)
+            return t
+
+        self.weight = maybe_cast(self.weight)
+        self.weight_scale = maybe_cast(self.weight_scale)
+        self.bias = maybe_cast(self.bias)
+        return self
```
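A quick usage sketch for the new layer, with hypothetical dimensions and a manually quantized weight standing in for a real quantized checkpoint:

```python
import torch
from lightx2v.models.input_encoders.hf.q_linear import TorchaoQuantLinearInt8

# Hypothetical dimensions; assumes a CUDA device.
layer = TorchaoQuantLinearInt8(in_features=4096, out_features=8192, bias=False).cuda()

# Fill the registered buffers as a quantized checkpoint would
# (per-channel symmetric int8 plus a float32 scale per output channel).
w = torch.randn(8192, 4096, device="cuda")
scale = w.abs().amax(dim=1, keepdim=True) / 127.0
layer.weight.copy_(torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8))
layer.weight_scale.copy_(scale)

# forward() squeezes/unsqueezes dim 0, so inputs are (1, tokens, in_features).
x = torch.randn(1, 32, 4096, dtype=torch.bfloat16, device="cuda")
y = layer(x)
print(y.shape)  # torch.Size([1, 32, 8192])
```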
**lightx2v/models/input_encoders/hf/t5/model.py**
```diff
@@ -9,7 +9,7 @@ import torch.nn.functional as F
 from .tokenizer import HuggingfaceTokenizer
 from loguru import logger
-from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8
+from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
 
 __all__ = [
@@ -83,9 +83,11 @@ class T5Attention(nn.Module):
         if quantized:
             if quant_scheme == "int8":
-                linear_cls = QuantLinearInt8
+                linear_cls = VllmQuantLinearInt8
             elif quant_scheme == "fp8":
-                linear_cls = QuantLinearFp8
+                linear_cls = VllmQuantLinearFp8
+            elif quant_scheme == "int8-torchao":
+                linear_cls = TorchaoQuantLinearInt8
         else:
             linear_cls = nn.Linear
@@ -144,9 +146,11 @@ class T5FeedForward(nn.Module):
         if quantized:
             if quant_scheme == "int8":
-                linear_cls = QuantLinearInt8
+                linear_cls = VllmQuantLinearInt8
             elif quant_scheme == "fp8":
-                linear_cls = QuantLinearFp8
+                linear_cls = VllmQuantLinearFp8
+            elif quant_scheme == "int8-torchao":
+                linear_cls = TorchaoQuantLinearInt8
         else:
             linear_cls = nn.Linear
 
         # layers
```
**lightx2v/models/input_encoders/hf/xlm_roberta/model.py**
```diff
@@ -10,7 +10,7 @@ import torchvision.transforms as T
 from lightx2v.attentions import attention
 from loguru import logger
-from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8
+from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
 from einops import rearrange
 from torch import Tensor
 from transformers import CLIPVisionModel
@@ -63,9 +63,11 @@ class SelfAttention(nn.Module):
         # layers
         if quantized:
             if quant_scheme == "int8":
-                linear_cls = QuantLinearInt8
+                linear_cls = VllmQuantLinearInt8
             elif quant_scheme == "fp8":
-                linear_cls = QuantLinearFp8
+                linear_cls = VllmQuantLinearFp8
+            elif quant_scheme == "int8-torchao":
+                linear_cls = TorchaoQuantLinearInt8
         else:
             linear_cls = nn.Linear
@@ -135,9 +137,11 @@ class AttentionBlock(nn.Module):
         # layers
         if quantized:
             if quant_scheme == "int8":
-                linear_cls = QuantLinearInt8
+                linear_cls = VllmQuantLinearInt8
             elif quant_scheme == "fp8":
-                linear_cls = QuantLinearFp8
+                linear_cls = VllmQuantLinearFp8
+            elif quant_scheme == "int8-torchao":
+                linear_cls = TorchaoQuantLinearInt8
         else:
             linear_cls = nn.Linear
```
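The same scheme-to-class dispatch now appears four times, in `T5Attention`, `T5FeedForward`, `SelfAttention`, and `AttentionBlock`. A small helper, sketched below, is one way to express it once; the helper is hypothetical and not part of this commit:

```python
import torch.nn as nn

from lightx2v.models.input_encoders.hf.q_linear import (
    TorchaoQuantLinearInt8,
    VllmQuantLinearFp8,
    VllmQuantLinearInt8,
)


def select_linear_cls(quantized, quant_scheme):
    """Map a quant_scheme string to a linear layer class (hypothetical helper)."""
    if not quantized:
        return nn.Linear
    # An unknown scheme raises KeyError here; the inline version would leave
    # linear_cls unbound and fail later, so failing fast is no worse.
    return {
        "int8": VllmQuantLinearInt8,
        "fp8": VllmQuantLinearFp8,
        "int8-torchao": TorchaoQuantLinearInt8,
    }[quant_scheme]
```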
**lightx2v/models/networks/wan/distill_model.py**
```diff
@@ -11,6 +11,7 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
     WanTransformerWeights,
 )
 from lightx2v.utils.envs import *
+from loguru import logger
 
 
 class WanDistillModel(WanModel):
@@ -31,8 +32,11 @@ class WanDistillModel(WanModel):
             return weight_dict
 
         ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/distill_model.pt")
         if os.path.exists(ckpt_path):
+            logger.info(f"Loading weights from {ckpt_path}")
             weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+            print(weight_dict.keys())
             weight_dict = {key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()}
```
**lightx2v/models/runners/wan/wan_runner.py**
```diff
@@ -57,17 +57,19 @@ class WanRunner(DefaultRunner):
         if clip_quantized:
             clip_quant_scheme = self.config.get("clip_quant_scheme", None)
             assert clip_quant_scheme is not None
+            tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
             clip_quantized_ckpt = self.config.get(
                 "clip_quantized_ckpt",
                 os.path.join(
-                    os.path.join(self.config.model_path, clip_quant_scheme),
-                    f"clip-{clip_quant_scheme}.pth",
+                    os.path.join(self.config.model_path, tmp_clip_quant_scheme),
+                    f"clip-{tmp_clip_quant_scheme}.pth",
                 ),
             )
         else:
             clip_quantized_ckpt = None
             clip_quant_scheme = None
+        print(clip_quant_scheme)
         image_encoder = CLIPModel(
             dtype=torch.float16,
             device=self.init_device,
@@ -93,18 +95,19 @@ class WanRunner(DefaultRunner):
         t5_quantized = self.config.get("t5_quantized", False)
         if t5_quantized:
             t5_quant_scheme = self.config.get("t5_quant_scheme", None)
+            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
             assert t5_quant_scheme is not None
             t5_quantized_ckpt = self.config.get(
                 "t5_quantized_ckpt",
                 os.path.join(
-                    os.path.join(self.config.model_path, t5_quant_scheme),
-                    f"models_t5_umt5-xxl-enc-{t5_quant_scheme}.pth",
+                    os.path.join(self.config.model_path, tmp_t5_quant_scheme),
+                    f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth",
                 ),
             )
         else:
             t5_quant_scheme = None
             t5_quantized_ckpt = None
+        print(t5_quant_scheme)
         text_encoder = T5EncoderModel(
             text_len=self.config["text_len"],
             dtype=torch.bfloat16,
```
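The `split("-")[0]` prefixing means a scheme such as `"int8-torchao"` resolves to the same default checkpoint as plain `"int8"`, so the torchao path reuses the existing int8 weight files. A worked example of the default-path logic above (the model path is hypothetical):

```python
import os

model_path = "/models/wan"  # hypothetical
t5_quant_scheme = "int8-torchao"
tmp = t5_quant_scheme.split("-")[0]  # "int8"

# Mirrors the default t5_quantized_ckpt computation in load_text_encoder.
ckpt = os.path.join(os.path.join(model_path, tmp), f"models_t5_umt5-xxl-enc-{tmp}.pth")
print(ckpt)  # /models/wan/int8/models_t5_umt5-xxl-enc-int8.pth
```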