Commit 57b0ad8e authored by lifu

add qwen int8

parent 5e2c95b7
......@@ -131,9 +131,7 @@ class BaseModel(torch.nn.Module):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
#operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
#rndi
int8 = model_config.optimizations.get("int8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8,int8_optimizations=int8)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
......
......@@ -24,6 +24,9 @@ import comfy.float
import comfy.rmsnorm
import contextlib
import triton
import triton.language as tl
from triton.language.extra import libdevice
try:
from lmslim import quant_ops
......@@ -318,7 +321,7 @@ class manual_cast(disable_weight_init):
from typing import Optional
class manual_cast_int8_per_channel(manual_cast):
class manual_cast_int8(manual_cast):
class Linear(torch.nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=None, device=None):
super().__init__()
......@@ -365,8 +368,6 @@ class manual_cast_int8_per_channel(manual_cast):
return w_q, scales
def forward(self, input):
#return self.forward_calibration(input)
dim = input.dim()
if dim > 2:
input = input.squeeze(0)
......@@ -383,45 +384,87 @@ class manual_cast_int8_per_channel(manual_cast):
return output_tensor
class manual_cast_int8(manual_cast):
class Linear(torch.nn.Module, CastWeightBiasOp):
__constants__ = ['in_features', 'out_features']
in_features: int
out_features: int
weight: torch.Tensor
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
@triton.jit
def _per_token_quant_int8(
x_ptr,
xq_ptr,
s_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
BLOCK: tl.constexpr,
):
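# one program instance per token (row): apply the per-channel smoothing factors from s_ptr,
# then quantize that row symmetrically to int8 with a single per-row scale written to scale_ptr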
row_id = tl.program_id(0)
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
s = tl.load(s_ptr + cols, mask=mask, other=0.0).to(tl.float32)
x = x * s
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + row_id, scale_x)
def per_token_quant_int8_smooth(x, s):
M = x.numel() // x.shape[-1]
N = x.shape[-1]
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(x.shape[:-1] + (1, ),
device=x.device,
dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
_per_token_quant_int8[(M, )](
x,
x_q,
s,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return x_q, scales
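# Illustrative reference (not part of the commit): a plain PyTorch equivalent of the Triton
# kernel above, handy for sanity-checking its output. Assumes a 2D input x of shape [M, N]
# and per-channel smoothing factors s of shape [N]; the helper name is made up here.
def per_token_quant_int8_smooth_ref(x, s):
    x = x.float() * s.float()                                     # apply smoothing factors per channel
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)  # per-row max magnitude
    scales = absmax / 127.0                                       # one fp32 scale per token (row)
    x_q = torch.round(x / scales).clamp(-127, 127).to(torch.int8)
    return x_q, scales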
class manual_cast_int8_smooth(manual_cast):
class Linear(torch.nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=None, device=None):
super().__init__()
print("=============use int8==============")
self.in_features = in_features
self.out_features = out_features
# self.weight = Parameter(torch.empty((out_features, in_features),dtype=torch.int8, device=device))
# self.weight_scale = Parameter(torch.empty((out_features,1), **factory_kwargs))
self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8, device=device))
self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float16, device=device))
self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), dtype=dtype, device=device), requires_grad=False)
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features,dtype=torch.float16, device=device))
self.bias = torch.nn.Parameter(torch.empty(out_features, dtype=dtype, device=device))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self) -> None:
return None
def verify_quant_gemm(self,input_q,weight_q,input_scale, weight_scale,out_dtype: torch.dtype,
bias):
self.register_parameter("bias", None)
# 2. INT GEMM
# (int8 matmul -> cast to int32 accumulated result)
y_q = (input_q.cpu().int() @ (weight_q.cpu().int().t()))
self.weight_quant = None
self.weight_scale = None
self.scales_rcp = None
# 3. Dequantize
y_deq = y_q * ((input_scale * weight_scale.t()).cpu())
self.act_scales = None
self.count = 0
self.alpha = 0.6
# 4. Reference FP32 GEMM
return y_deq.to(out_dtype).cuda()
self.scales = torch.nn.Parameter(torch.empty(in_features, dtype=dtype, device=device), requires_grad=False)
def blaslt_scaled_mm(self,
a: torch.Tensor,
......@@ -429,80 +472,108 @@ class manual_cast_int8(manual_cast):
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias) -> torch.Tensor:
# b = b.t()
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
m = a.shape[0]
n = b.shape[0]
k = a.shape[1]
# import pdb
# pdb.set_trace()
stat, output = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k, 'NT', out_dtype)
# output = matmul_int8(a, scale_a, b, scale_b, out_dtype, config=None)
# status, output = torch.ops.lmslim.lightop_channel_int8_mm(a, b, scale_a, scale_b, out_dtype, bias)
_, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a.to(torch.float32), scale_b.to(torch.float32), m, n, k, 'NT', out_dtype)
if bias is not None:
output += bias
# torch.cuda.synchronize()
# out = torch.rand((m, n),dtype=torch.bfloat16, device=a.device)
return output
def quantize_symmetric_per_row_int8(self, x: torch.Tensor):
"""
Symmetric per-row (dim=1) INT8 quantization of the input x.
Args:
x: tensor of shape [B, N], dtype in {float32, float16, bfloat16}
Returns:
x_q: quantized int8 tensor, shape [B, N]
scales: scale per row, shape [B, 1], same dtype as x
"""
assert x.ndim == 2, f"Expected 2D input, got {x.shape}"
assert x.dtype in [torch.float32, torch.float16, torch.bfloat16]
# Step 1: compute the per-row max absolute value -> shape [B, 1]
max_abs = x.abs().amax(dim=1, keepdim=True) # keepdim=True keeps shape [B, 1]
# Step 2: compute scale = max_abs / 127
# avoid division by zero: if a row is all zeros, use scale = 1
scales = torch.where(
max_abs == 0,
torch.tensor(1.0, dtype=x.dtype, device=x.device),
max_abs / 127.0
) # shape [B, 1], dtype = x.dtype
# Step 3: quantize: x_q = round(x / scales)
# use float32 for the intermediate computation to avoid bfloat16 precision issues
x_f32 = x.to(torch.float32)
scales_f32 = scales.to(torch.float32)
x_q_f32 = torch.round(x_f32 / scales_f32)
# Step 4: clamp to [-127, 127] and cast to int8
x_q = torch.clamp(x_q_f32, -127, 127).to(torch.int8)
return x_q, scales_f32
def forward(self, input_tensor: torch.Tensor):
# import pdb
# pdb.set_trace()
dim = input_tensor.dim()
out += bias
return out
def weight_quant_int8(self, weight):
org_w_shape = weight.shape
w = weight.to(torch.bfloat16)
max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
qmin, qmax = -128, 127
scales = (max_val / qmax).float()
w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(torch.int8)
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w_q).sum() == 0
scales = scales.view(org_w_shape[0], -1)
w_q = w_q.reshape(org_w_shape)
return w_q, scales
def per_token_quant_int8_torch(self, input):
org_input_shape = input.shape
max_val = input.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
qmin, qmax = -128, 127
scales = max_val / qmax
input_q = torch.clamp(torch.round(input / scales), qmin, qmax).to(torch.int8)
assert torch.isnan(scales).sum() == 0
assert torch.isnan(input_q).sum() == 0
return input_q, scales
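# Illustrative note: the W8A8 GEMM path (blaslt_scaled_mm) dequantizes with the broadcasted
# outer product of the two scale vectors; conceptually, with
# x_q, s_x = per_token_quant_int8_torch(x) and w_q, s_w = weight_quant_int8(W):
#     x @ W.T  ≈  (x_q.float() @ w_q.float().T) * (s_x * s_w.T)
# which matches the reference computation done by the removed verify_quant_gemm helper.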
def forward(self, input):
#return self.forward_calibration(input)
dim = input.dim()
if dim > 2:
input = input.squeeze(0)
if self.weight_quant is None:
weight_smooth = self.weight * self.scales
self.scales_rcp = 1.0 / self.scales
self.weight_quant, self.weight_scale = per_token_quant_int8(weight_smooth)
del self.weight
input_quant, input_scale = per_token_quant_int8_smooth(input, self.scales_rcp)
output_tensor = self.blaslt_scaled_mm(input_quant, self.weight_quant, input_scale, self.weight_scale, torch.bfloat16, self.bias)
if dim > 2:
input_tensor = input_tensor.squeeze(0)
dtype = input_tensor.dtype
# print
# import pdb
# pdb.set_trace()
input_tensor_quant, input_tensor_scale = per_token_quant_int8(input_tensor)
# input_tensor_quant, input_tensor_scale = self.quantize_symmetric_per_row_int8(input_tensor)
output_tensor = output_tensor.unsqueeze(0)
output_tensor = self.blaslt_scaled_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale.to(torch.float32), dtype, self.bias)
# output_sf = self.verify_quant_gemm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale.to(torch.float32), dtype, self.bias)
return output_tensor
def forward_calibration(self, input):
dim = input.dim()
if dim > 2:
input = input.squeeze(0)
if self.count < 48:
self.calibration(input)
output_tensor = torch.mm(input, self.weight.to(torch.bfloat16).t())
if self.bias is not None:
output_tensor += self.bias.to(torch.bfloat16)
if dim > 2:
output_tensor = output_tensor.unsqueeze(0)
return output_tensor
def extra_repr(self) -> str:
return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'
def calibration(self, input):
self.count += 1
if self.count == 1:
self.weight_max = torch.max(self.weight.to(torch.bfloat16), dim=0)[0].clamp(min=1e-5).cpu()
if self.count <= 48:
tensor = input.abs()
comming_max = torch.max(tensor, dim=0)[0].cpu()
if self.act_scales is not None:
self.act_scales = torch.max(self.act_scales, comming_max)
else:
self.act_scales = comming_max
if self.count == 48:
print(f"====================================={self.count}==========================================")
print(f"weight dtype: {self.weight.dtype} bias : {self.bias.dtype}")
# print("act_max: ",self.act_scales)
# print("weight_max: ",self.weight_max)
self.scales.data = (torch.pow(self.act_scales, self.alpha) / torch.pow(self.weight_max, 1 - self.alpha)).clamp(min=1e-5).cuda()
# print("pow(|act_max|, alpha) / pow(|weight_max|, 1-alpha): ",self.scales)
# print(f"scales min: {self.scales.min().item()}, max: {self.scales.max().item()}")
# print(f"scales has NaN: {torch.any(torch.isnan(self.scales))}")
# print(f"scales has INF: {torch.any(torch.isinf(self.scales))}")
# print(f"scales has zero: {torch.any(self.scales == 0)}")
def fp8_linear(self, input):
dtype = self.weight.dtype
......@@ -636,9 +707,12 @@ if CUBLAS_IS_AVAILABLE:
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, int8_optimizations=None):
if int8_optimizations is not None and int8_optimizations:
return manual_cast_int8_per_channel
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
if model_config is not None and model_config.optimizations.get("int8", False):
if model_config.unet_config.get("image_model", "") == "qwen_image":
return manual_cast_int8_smooth
return manual_cast_int8
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
......
/root/models/
\ No newline at end of file
/home/models
\ No newline at end of file
......@@ -912,8 +912,9 @@ class UNETLoader:
if weight_dtype == "fp8_e4m3fn":
model_options["dtype"] = torch.float8_e4m3fn
elif weight_dtype == "fp8_e4m3fn_fast":
print("##### PANN_DEBUG UNETLoader fp8_e4m3fn_fast ####")
model_options["dtype"] = torch.float8_e4m3fn
if unet_name == "Qwen-Image-Edit-2509-smooth-int8.safetensors":
model_options["dtype"] = torch.bfloat16
#model_options["fp8_optimizations"] = True
model_options["int8_optimizations"] = True
elif weight_dtype == "fp8_e5m2":
......@@ -922,7 +923,6 @@ class UNETLoader:
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
#model.model = model.model.to(memory_format=torch.channels_last)
#print(model.model)
return (model,)
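# Illustrative note: with this change, selecting weight_dtype "fp8_e4m3fn_fast" together with
# the "Qwen-Image-Edit-2509-smooth-int8.safetensors" checkpoint keeps the weights in bfloat16
# and sets model_options["int8_optimizations"]; assuming the loader copies that flag into
# model_config.optimizations["int8"], pick_operations will then route the model's Linear
# layers to manual_cast_int8_smooth for Qwen image models.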
class CLIPLoader:
......
......@@ -462,9 +462,11 @@ def test6(server_url: str):
image_mapping = {}
for weight_dtype in ['default', 'fp8_e4m3fn', 'fp8_e4m3fn_fast']:
#for weight_dtype in ['fp8_e4m3fn_fast']:
logger.info(f'\n========> {workflow_name} {weight_dtype} <========')
if weight_dtype == "fp8_e4m3fn_fast":
api_prompt["236"]["inputs"]["unet_name"] = "Qwen-Image-Edit-2509-smooth-int8.safetensors"
recorder = TimingRecorder()
for idx, (image1, image2) in enumerate(test_cases):
api_prompt["247"]["inputs"]["image"] = image1
api_prompt["248"]["inputs"]["image"] = image2
......@@ -613,7 +615,7 @@ if __name__ == "__main__":
#test6(server_url)
# Test old photo restoration workflow
#test7(server_url)
test7(server_url)
#test7(server_url)
#test7(server_url)
......