wuxk1 / dcu-Comfyui · Commit 57b0ad8e
authored Jan 07, 2026 by lifu

add qwen int8

parent 5e2c95b7

Showing 5 changed files with 188 additions and 114 deletions (+188 -114):

comfy/model_base.py       +1    -3
comfy/ops.py              +180  -106
models                    +1    -1
nodes.py                  +2    -2
workflow_test/test_c.py   +4    -2
comfy/model_base.py

@@ -131,9 +131,7 @@ class BaseModel(torch.nn.Module):
         if model_config.custom_operations is None:
             fp8 = model_config.optimizations.get("fp8", False)
             #operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
-            #rndi
-            int8 = model_config.optimizations.get("int8", False)
-            operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, int8_optimizations=int8)
+            operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
         else:
             operations = model_config.custom_operations
         self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
comfy/ops.py

@@ -24,6 +24,9 @@ import comfy.float
 import comfy.rmsnorm
 import contextlib
+import triton
+import triton.language as tl
+from triton.language.extra import libdevice
 try:
     from lmslim import quant_ops
@@ -318,7 +321,7 @@ class manual_cast(disable_weight_init):
 from typing import Optional

-class manual_cast_int8_per_channel(manual_cast):
+class manual_cast_int8(manual_cast):
     class Linear(torch.nn.Module):
         def __init__(self, in_features, out_features, bias=True, dtype=None, device=None):
             super().__init__()
@@ -365,8 +368,6 @@ class manual_cast_int8_per_channel(manual_cast):
             return w_q, scales

         def forward(self, input):
-            #return self.forward_calibration(input)
             dim = input.dim()
             if dim > 2:
                 input = input.squeeze(0)
@@ -383,45 +384,87 @@ class manual_cast_int8_per_channel(manual_cast):
             return output_tensor

-class manual_cast_int8(manual_cast):
-    class Linear(torch.nn.Module, CastWeightBiasOp):
-        __constants__ = ['in_features', 'out_features']
-        in_features: int
-        out_features: int
-        weight: torch.Tensor
-
-        def __init__(self, in_features: int, out_features: int, bias: bool = True,
-                     device=None, dtype=None) -> None:
-            factory_kwargs = {'device': device, 'dtype': dtype}
+@triton.jit
+def _per_token_quant_int8(
+    x_ptr,
+    xq_ptr,
+    s_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    BLOCK: tl.constexpr,
+):
+    # One program per row: multiply the row by the smoothing factors s, then
+    # quantize symmetrically to int8 with a per-row scale of absmax / 127.
+    row_id = tl.program_id(0)
+    cols = tl.arange(0, BLOCK)
+    mask = cols < N
+    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+    s = tl.load(s_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    x = x * s
+    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+    scale_x = absmax / 127
+    x_q = x * (127 / absmax)
+    x_q = libdevice.nearbyint(x_q).to(tl.int8)
+    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
+    tl.store(scale_ptr + row_id, scale_x)
+
+
+def per_token_quant_int8_smooth(x, s):
+    M = x.numel() // x.shape[-1]
+    N = x.shape[-1]
+    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    _per_token_quant_int8[(M,)](
+        x,
+        x_q,
+        s,
+        scales,
+        stride_x=x.stride(-2),
+        stride_xq=x_q.stride(-2),
+        N=N,
+        BLOCK=BLOCK,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return x_q, scales
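Note: an eager-mode reference for the kernel above, as a minimal sketch that is not part of the commit (torch.round and libdevice.nearbyint both round half to even, so the two paths should agree):

import torch

def per_token_quant_int8_smooth_ref(x, s):
    xs = x.float() * s.float()                                     # apply smoothing factors
    absmax = xs.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)  # per-row absmax
    scales = absmax / 127.0                                        # symmetric per-row scale
    x_q = torch.round(xs / scales).to(torch.int8)
    return x_q, scales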
+class manual_cast_int8_smooth(manual_cast):
+    class Linear(torch.nn.Module):
+        def __init__(self, in_features, out_features, bias=True, dtype=None, device=None):
+            super().__init__()
+            print("=============use int8==============")
+            self.in_features = in_features
+            self.out_features = out_features
+            # self.weight = Parameter(torch.empty((out_features, in_features), dtype=torch.int8, device=device))
+            # self.weight_scale = Parameter(torch.empty((out_features, 1), **factory_kwargs))
-            self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8, device=device))
-            self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float16, device=device))
+            self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), dtype=dtype, device=device), requires_grad=False)
             if bias:
-                self.bias = torch.nn.Parameter(torch.empty(out_features, dtype=torch.float16, device=device))
+                self.bias = torch.nn.Parameter(torch.empty(out_features, dtype=dtype, device=device))
             else:
                 self.register_parameter('bias', None)
+            self.weight_quant = None
+            self.weight_scale = None
+            self.scales_rcp = None
+            self.act_scales = None
+            self.count = 0
+            self.alpha = 0.6
+            self.scales = torch.nn.Parameter(torch.empty(in_features, dtype=dtype, device=device), requires_grad=False)
             self.reset_parameters()

         def reset_parameters(self) -> None:
             return None

-        def verify_quant_gemm(self, input_q, weight_q, input_scale, weight_scale, out_dtype: torch.dtype, bias):
-            # 2. INT GEMM (int8 matmul -> int32-accumulated result)
-            y_q = input_q.cpu().int() @ weight_q.cpu().int().t()
-            # 3. Dequantize
-            y_deq = y_q * (input_scale * weight_scale.t()).cpu()
-            # 4. Reference FP32 GEMM
-            return y_deq.to(out_dtype).cuda()

         def blaslt_scaled_mm(self,
                              a: torch.Tensor,
@@ -429,80 +472,108 @@ class manual_cast_int8(manual_cast):
                              b: torch.Tensor,
                              scale_a: torch.Tensor,
                              scale_b: torch.Tensor,
                              out_dtype: torch.dtype,
-                             bias) -> torch.Tensor:
-            # b = b.t()
+                             bias: Optional[torch.Tensor] = None) -> torch.Tensor:
             m = a.shape[0]
             n = b.shape[0]
             k = a.shape[1]
-            # import pdb
-            # pdb.set_trace()
-            stat, output = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k, 'NT', out_dtype)
-            # output = matmul_int8(a, scale_a, b, scale_b, out_dtype, config=None)
-            # status, output = torch.ops.lmslim.lightop_channel_int8_mm(a, b, scale_a, scale_b, out_dtype, bias)
-            if bias is not None:
-                output += bias
-            # torch.cuda.synchronize()
-            # out = torch.rand((m, n), dtype=torch.bfloat16, device=a.device)
-            return output
+            _, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a.to(torch.float32), scale_b.to(torch.float32), m, n, k, 'NT', out_dtype)
+            if bias is not None:
+                out += bias
+            return out
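Note: the removed verify_quant_gemm spelled out what the fused call is expected to compute. A minimal CPU sketch of that contract (mine, assuming 'NT' takes a row-major [m, k] A and a row-major [n, k] B; hipblaslt_w8a8_gemm itself is provided by lmslim on the DCU):

import torch

def w8a8_gemm_ref(a_q, b_q, scale_a, scale_b, out_dtype=torch.bfloat16):
    # a_q: [m, k] int8, b_q: [n, k] int8, scale_a: [m, 1], scale_b: [n, 1]
    acc = a_q.int() @ b_q.int().t()                    # int32 accumulation
    return (acc * (scale_a * scale_b.t())).to(out_dtype)  # dequantize with outer product of scales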
         def quantize_symmetric_per_row_int8(self, x: torch.Tensor):
             """
             Symmetric per-row (dim=1) INT8 quantization of x.
             Args:
                 x: tensor of shape [B, N], dtype in {float32, float16, bfloat16}
             Returns:
                 x_q: quantized int8 tensor, shape [B, N]
                 scales: per-row scale, shape [B, 1], returned as float32
             """
             assert x.ndim == 2, f"Expected 2D input, got {x.shape}"
             assert x.dtype in [torch.float32, torch.float16, torch.bfloat16]
             # Step 1: per-row max absolute value -> shape [B, 1]
             max_abs = x.abs().amax(dim=1, keepdim=True)  # keepdim=True keeps shape [B, 1]
             # Step 2: scale = max_abs / 127
             # Avoid division by zero: if a row is all zeros, use scale = 1
             scales = torch.where(max_abs == 0,
                                  torch.tensor(1.0, dtype=x.dtype, device=x.device),
                                  max_abs / 127.0)  # shape [B, 1], dtype = x.dtype
             # Step 3: quantize: x_q = round(x / scales)
             # Use float32 intermediates to avoid bfloat16 precision loss
             x_f32 = x.to(torch.float32)
             scales_f32 = scales.to(torch.float32)
             x_q_f32 = torch.round(x_f32 / scales_f32)
             # Step 4: clamp to [-127, 127] and cast to int8
             x_q = torch.clamp(x_q_f32, -127, 127).to(torch.int8)
             return x_q, scales_f32
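Note: a round-trip sanity check for this quantizer, as a sketch (`lin` stands for any of the Linear layers above):

x = torch.randn(32, 64, dtype=torch.float16, device="cuda")
x_q, s = lin.quantize_symmetric_per_row_int8(x)
# symmetric rounding keeps the dequantization error within half a step per row
assert ((x.float() - x_q.float() * s).abs() <= s * 0.5 + 1e-6).all()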
-        def forward(self, input_tensor: torch.Tensor):
-            # import pdb
-            # pdb.set_trace()
-            dim = input_tensor.dim()
-            if dim > 2:
-                input_tensor = input_tensor.squeeze(0)
-            dtype = input_tensor.dtype
-            input_tensor_quant, input_tensor_scale = per_token_quant_int8(input_tensor)
-            # input_tensor_quant, input_tensor_scale = self.quantize_symmetric_per_row_int8(input_tensor)
-            output_tensor = self.blaslt_scaled_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale.to(torch.float32), dtype, self.bias)
-            # output_sf = self.verify_quant_gemm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale.to(torch.float32), dtype, self.bias)
-            if dim > 2:
-                output_tensor = output_tensor.unsqueeze(0)
-            return output_tensor
+        def weight_quant_int8(self, weight):
+            # Per-output-channel symmetric int8 quantization of the weight.
+            org_w_shape = weight.shape
+            w = weight.to(torch.bfloat16)
+            max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
+            qmin, qmax = -128, 127
+            scales = (max_val / qmax).float()
+            w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(torch.int8)
+            assert torch.isnan(scales).sum() == 0
+            assert torch.isnan(w_q).sum() == 0
+            scales = scales.view(org_w_shape[0], -1)
+            w_q = w_q.reshape(org_w_shape)
+            return w_q, scales
+
+        def per_token_quant_int8_torch(self, input):
+            # Per-token (per-row) symmetric int8 quantization, eager-mode fallback.
+            org_input_shape = input.shape
+            max_val = input.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
+            qmin, qmax = -128, 127
+            scales = max_val / qmax
+            input_q = torch.clamp(torch.round(input / scales), qmin, qmax).to(torch.int8)
+            assert torch.isnan(scales).sum() == 0
+            assert torch.isnan(input_q).sum() == 0
+            return input_q, scales
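Note: shape contract for the two helpers, as a quick usage sketch (the `layer` instance here is hypothetical):

layer = manual_cast_int8_smooth.Linear(8, 16, bias=False, dtype=torch.bfloat16, device="cuda")
w_q, w_s = layer.weight_quant_int8(torch.randn(16, 8, device="cuda"))           # int8 [16, 8], float32 scales [16, 1]
x_q, x_s = layer.per_token_quant_int8_torch(torch.randn(4, 8, device="cuda"))   # int8 [4, 8], scales [4, 1]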
+        def forward(self, input):
+            #return self.forward_calibration(input)
+            dim = input.dim()
+            if dim > 2:
+                input = input.squeeze(0)
+            if self.weight_quant is None:
+                # First call: fold the SmoothQuant factors into the weight,
+                # quantize it once, then drop the floating-point copy.
+                weight_smooth = self.weight * self.scales
+                self.scales_rcp = 1.0 / self.scales
+                self.weight_quant, self.weight_scale = per_token_quant_int8(weight_smooth)
+                del self.weight
+            input_quant, input_scale = per_token_quant_int8_smooth(input, self.scales_rcp)
+            output_tensor = self.blaslt_scaled_mm(input_quant, self.weight_quant, input_scale, self.weight_scale, torch.bfloat16, self.bias)
+            if dim > 2:
+                output_tensor = output_tensor.unsqueeze(0)
+            return output_tensor
+        def forward_calibration(self, input):
+            dim = input.dim()
+            if dim > 2:
+                input = input.squeeze(0)
+            if self.count < 48:
+                self.calibration(input)
+            output_tensor = torch.mm(input, self.weight.to(torch.bfloat16).t())
+            if self.bias is not None:
+                output_tensor += self.bias.to(torch.bfloat16)
+            if dim > 2:
+                output_tensor = output_tensor.unsqueeze(0)
+            return output_tensor

         def extra_repr(self) -> str:
             return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'

+        def calibration(self, input):
+            # Track per-channel activation maxima over the first 48 calls, then set
+            # the SmoothQuant factors s = act_max**alpha / weight_max**(1 - alpha).
+            self.count += 1
+            if self.count == 1:
+                self.weight_max = torch.max(self.weight.to(torch.bfloat16), dim=0)[0].clamp(min=1e-5).cpu()
+            if self.count <= 48:
+                tensor = input.abs()
+                comming_max = torch.max(tensor, dim=0)[0].cpu()
+                if self.act_scales is not None:
+                    self.act_scales = torch.max(self.act_scales, comming_max)
+                else:
+                    self.act_scales = comming_max
+            if self.count == 48:
+                print(f"===================================== {self.count} ==========================================")
+                print(f"weight dtype: {self.weight.dtype} bias: {self.bias.dtype}")
+                # print("act_max: ", self.act_scales)
+                # print("weight_max: ", self.weight_max)
+                self.scales.data = (torch.pow(self.act_scales, self.alpha) / torch.pow(self.weight_max, 1 - self.alpha)).clamp(min=1e-5).cuda()
+                # print("pow(|act_max|, alpha) / pow(|weight_max|, 1-alpha): ", self.scales)
+                # print(f"scales min: {self.scales.min().item()}, max: {self.scales.max().item()}")
+                # print(f"scales has NaN: {torch.any(torch.isnan(self.scales))}")
+                # print(f"scales has INF: {torch.any(torch.isinf(self.scales))}")
+                # print(f"scales has zero: {torch.any(self.scales == 0)}")

 def fp8_linear(self, input):
     dtype = self.weight.dtype
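Note: the calibration above is the SmoothQuant recipe; it migrates quantization difficulty from activations to weights. A minimal sketch (mine, independent of the commit) of the identity it relies on: dividing activations by s and multiplying weight columns by s leaves the floating-point matmul unchanged, while flattening activation outliers before per-token int8 quantization.

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)                  # activations [tokens, in_features]
W = torch.randn(16, 8)                 # weight [out_features, in_features]
s = torch.rand(8) + 0.5                # per-channel factors, as computed in calibration()
y_ref = x @ W.t()
y_smooth = (x * (1.0 / s)) @ (W * s).t()
assert torch.allclose(y_ref, y_smooth, atol=1e-5)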
@@ -636,9 +707,12 @@ if CUBLAS_IS_AVAILABLE:
         def forward(self, *args, **kwargs):
             return super().forward(*args, **kwargs)

-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, int8_optimizations=None):
-    if int8_optimizations is not None and int8_optimizations:
-        return manual_cast_int8_per_channel
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
+    if model_config is not None and model_config.optimizations.get("int8", False):
+        if model_config.unet_config.get("image_model", "") == "qwen_image":
+            return manual_cast_int8_smooth
+        return manual_cast_int8
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
         return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
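Note: with this change the int8 path is selected from the model config rather than a bare flag. A minimal usage sketch (hypothetical model_config values, mirroring the branch above):

# assuming model_config carries optimizations={"int8": True} and
# unet_config={"image_model": "qwen_image"}:
ops_cls = comfy.ops.pick_operations(None, torch.bfloat16, model_config=model_config)
# -> manual_cast_int8_smooth for Qwen-Image, manual_cast_int8 for other int8 models
linear = ops_cls.Linear(3072, 3072, bias=True, dtype=torch.bfloat16, device="cuda")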
models

-/root/models/
\ No newline at end of file
+/home/models
\ No newline at end of file
nodes.py

@@ -912,8 +912,9 @@ class UNETLoader:
         if weight_dtype == "fp8_e4m3fn":
             model_options["dtype"] = torch.float8_e4m3fn
         elif weight_dtype == "fp8_e4m3fn_fast":
             print("##### PANN_DEBUG UNETLoader fp8_e4m3fn_fast ####")
             model_options["dtype"] = torch.float8_e4m3fn
             if unet_name == "Qwen-Image-Edit-2509-smooth-int8.safetensors":
                 model_options["dtype"] = torch.bfloat16
+                #model_options["fp8_optimizations"] = True
+                model_options["int8_optimizations"] = True
         elif weight_dtype == "fp8_e5m2":

@@ -922,7 +923,6 @@ class UNETLoader:
         unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
         model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
-        #model.model = model.model.to(memory_format=torch.channels_last)
         #print(model.model)
         return (model,)

 class CLIPLoader:
workflow_test/test_c.py

@@ -462,9 +462,11 @@ def test6(server_url: str):
     image_mapping = {}
     for weight_dtype in ['default', 'fp8_e4m3fn', 'fp8_e4m3fn_fast']:
     #for weight_dtype in ['fp8_e4m3fn_fast']:
         logger.info(f'\n========> {workflow_name} {weight_dtype} <========')
+        if weight_dtype == "fp8_e4m3fn_fast":
+            api_prompt["236"]["inputs"]["unet_name"] = "Qwen-Image-Edit-2509-smooth-int8.safetensors"
         recorder = TimingRecorder()
         for idx, (image1, image2) in enumerate(test_cases):
             api_prompt["247"]["inputs"]["image"] = image1
             api_prompt["248"]["inputs"]["image"] = image2

@@ -613,7 +615,7 @@ if __name__ == "__main__":
     #test6(server_url)
     # Test old photo restoration workflow
-    #test7(server_url)
+    test7(server_url)
     #test7(server_url)
     #test7(server_url)