Commit 57b0ad8e authored by lifu

add qwen int8

parent 5e2c95b7
@@ -131,9 +131,7 @@ class BaseModel(torch.nn.Module):
        if model_config.custom_operations is None:
            fp8 = model_config.optimizations.get("fp8", False)
            #operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
-            #rndi
-            int8 = model_config.optimizations.get("int8", False)
-            operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, int8_optimizations=int8)
+            operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
        else:
            operations = model_config.custom_operations
        self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
......
@@ -24,6 +24,9 @@ import comfy.float
import comfy.rmsnorm
import contextlib
+import triton
+import triton.language as tl
+from triton.language.extra import libdevice

try:
    from lmslim import quant_ops
@@ -318,7 +321,7 @@ class manual_cast(disable_weight_init):
from typing import Optional

-class manual_cast_int8_per_channel(manual_cast):
+class manual_cast_int8(manual_cast):
    class Linear(torch.nn.Module):
        def __init__(self, in_features, out_features, bias=True, dtype=None, device=None):
            super().__init__()
@@ -365,8 +368,6 @@ class manual_cast_int8_per_channel(manual_cast):
            return w_q, scales

        def forward(self, input):
-            #return self.forward_calibration(input)
-
            dim = input.dim()
            if dim > 2:
                input = input.squeeze(0)
@@ -383,126 +384,196 @@ class manual_cast_int8_per_channel(manual_cast):
            return output_tensor

-class manual_cast_int8(manual_cast):
-    class Linear(torch.nn.Module, CastWeightBiasOp):
-        __constants__ = ['in_features', 'out_features']
-        in_features: int
-        out_features: int
-        weight: torch.Tensor
-
-        def __init__(self, in_features: int, out_features: int, bias: bool = True,
-                     device=None, dtype=None) -> None:
-            factory_kwargs = {'device': device, 'dtype': dtype}
-            super().__init__()
-            print("=============use int8==============")
-            self.in_features = in_features
-            self.out_features = out_features
-            # self.weight = Parameter(torch.empty((out_features, in_features), dtype=torch.int8, device=device))
-            # self.weight_scale = Parameter(torch.empty((out_features, 1), **factory_kwargs))
-            self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8, device=device))
-            self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float16, device=device))
-            if bias:
-                self.bias = torch.nn.Parameter(torch.empty(out_features, dtype=torch.float16, device=device))
-            else:
-                self.register_parameter('bias', None)
-            self.reset_parameters()
-
-        def reset_parameters(self) -> None:
-            return None
-
-        def verify_quant_gemm(self, input_q, weight_q, input_scale, weight_scale, out_dtype: torch.dtype,
-                              bias):
-            # 2. INT GEMM
-            # (int8 matmul -> cast to int32 accumulated result)
-            y_q = (input_q.cpu().int() @ (weight_q.cpu().int().t()))
-            # 3. Dequantize
-            y_deq = y_q * ((input_scale * weight_scale.t()).cpu())
-            # 4. Reference FP32 GEMM
-            return y_deq.to(out_dtype).cuda()
-
-        def blaslt_scaled_mm(self,
-                             a: torch.Tensor,
-                             b: torch.Tensor,
-                             scale_a: torch.Tensor,
-                             scale_b: torch.Tensor,
-                             out_dtype: torch.dtype,
-                             bias) -> torch.Tensor:
-            # b = b.t()
-            m = a.shape[0]
-            n = b.shape[0]
-            k = a.shape[1]
-            # import pdb
-            # pdb.set_trace()
-            stat, output = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k, 'NT', out_dtype)
-            # output = matmul_int8(a, scale_a, b, scale_b, out_dtype, config=None)
-            # status, output = torch.ops.lmslim.lightop_channel_int8_mm(a, b, scale_a, scale_b, out_dtype, bias)
-            if bias is not None:
-                output += bias
-            # torch.cuda.synchronize()
-            # out = torch.rand((m, n), dtype=torch.bfloat16, device=a.device)
-            return output
-
-        def quantize_symmetric_per_row_int8(self, x: torch.Tensor):
-            """
-            Symmetric per-row (dim=1) INT8 quantization of the input x.
-
-            Args:
-                x: tensor of shape [B, N], dtype in {float32, float16, bfloat16}
-
-            Returns:
-                x_q: quantized int8 tensor, shape [B, N]
-                scales: scale per row, shape [B, 1], same dtype as x
-            """
-            assert x.ndim == 2, f"Expected 2D input, got {x.shape}"
-            assert x.dtype in [torch.float32, torch.float16, torch.bfloat16]
-            # Step 1: per-row max absolute value -> shape [B, 1]
-            max_abs = x.abs().amax(dim=1, keepdim=True)  # keepdim=True keeps shape [B, 1]
-            # Step 2: scale = max_abs / 127
-            # avoid division by zero: if a row is all zeros, use scale = 1
-            scales = torch.where(
-                max_abs == 0,
-                torch.tensor(1.0, dtype=x.dtype, device=x.device),
-                max_abs / 127.0
-            )  # shape [B, 1], dtype = x.dtype
-            # Step 3: quantize: x_q = round(x / scales)
-            # compute in float32 to avoid bfloat16 precision loss
-            x_f32 = x.to(torch.float32)
-            scales_f32 = scales.to(torch.float32)
-            x_q_f32 = torch.round(x_f32 / scales_f32)
-            # Step 4: clamp to [-127, 127] and cast to int8
-            x_q = torch.clamp(x_q_f32, -127, 127).to(torch.int8)
-            return x_q, scales_f32
-
-        def forward(self, input_tensor: torch.Tensor):
-            # import pdb
-            # pdb.set_trace()
-            dim = input_tensor.dim()
-            if dim > 2:
-                input_tensor = input_tensor.squeeze(0)
-            dtype = input_tensor.dtype
-            # print
-            # import pdb
-            # pdb.set_trace()
-            input_tensor_quant, input_tensor_scale = per_token_quant_int8(input_tensor)
-            # input_tensor_quant, input_tensor_scale = self.quantize_symmetric_per_row_int8(input_tensor)
-            output_tensor = self.blaslt_scaled_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale.to(torch.float32), dtype, self.bias)
-            # output_sf = self.verify_quant_gemm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale.to(torch.float32), dtype, self.bias)
-            if dim > 2:
-                output_tensor = output_tensor.unsqueeze(0)
-            return output_tensor
-
-        def extra_repr(self) -> str:
-            return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'
+@triton.jit
+def _per_token_quant_int8(
+    x_ptr,
+    xq_ptr,
+    s_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    BLOCK: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    cols = tl.arange(0, BLOCK)
+    mask = cols < N
+
+    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
+                other=0.0).to(tl.float32)
+    s = tl.load(s_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    x = x * s
+    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+    scale_x = absmax / 127
+    x_q = x * (127 / absmax)
+    x_q = libdevice.nearbyint(x_q).to(tl.int8)
+    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
+    tl.store(scale_ptr + row_id, scale_x)
+
+
+def per_token_quant_int8_smooth(x, s):
+    M = x.numel() // x.shape[-1]
+    N = x.shape[-1]
+    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
+    scales = torch.empty(x.shape[:-1] + (1, ),
+                         device=x.device,
+                         dtype=torch.float32)
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    _per_token_quant_int8[(M, )](
+        x,
+        x_q,
+        s,
+        scales,
+        stride_x=x.stride(-2),
+        stride_xq=x_q.stride(-2),
+        N=N,
+        BLOCK=BLOCK,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return x_q, scales
+
+
+class manual_cast_int8_smooth(manual_cast):
+    class Linear(torch.nn.Module):
+        def __init__(self, in_features, out_features, bias=True, dtype=None, device=None):
+            super().__init__()
+            self.in_features = in_features
+            self.out_features = out_features
+            self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), dtype=dtype, device=device), requires_grad=False)
+            if bias:
+                self.bias = torch.nn.Parameter(torch.empty(out_features, dtype=dtype, device=device))
+            else:
+                self.register_parameter("bias", None)
+
+            self.weight_quant = None
+            self.weight_scale = None
+            self.scales_rcp = None
+            self.act_scales = None
+            self.count = 0
+            self.alpha = 0.6
+            self.scales = torch.nn.Parameter(torch.empty(in_features, dtype=dtype, device=device), requires_grad=False)
+
+        def blaslt_scaled_mm(self,
+                             a: torch.Tensor,
+                             b: torch.Tensor,
+                             scale_a: torch.Tensor,
+                             scale_b: torch.Tensor,
+                             out_dtype: torch.dtype,
+                             bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+            m = a.shape[0]
+            n = b.shape[0]
+            k = a.shape[1]
+            _, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a.to(torch.float32), scale_b.to(torch.float32), m, n, k, 'NT', out_dtype)
+            if bias is not None:
+                out += bias
+            return out
+
+        def weight_quant_int8(self, weight):
+            org_w_shape = weight.shape
+            w = weight.to(torch.bfloat16)
+            max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
+            qmin, qmax = -128, 127
+            scales = (max_val / qmax).float()
+            w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(torch.int8)
+
+            assert torch.isnan(scales).sum() == 0
+            assert torch.isnan(w_q).sum() == 0
+
+            scales = scales.view(org_w_shape[0], -1)
+            w_q = w_q.reshape(org_w_shape)
+
+            return w_q, scales
+
+        def per_token_quant_int8_torch(self, input):
+            org_input_shape = input.shape
+            max_val = input.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
+            qmin, qmax = -128, 127
+            scales = max_val / qmax
+            input_q = torch.clamp(torch.round(input / scales), qmin, qmax).to(torch.int8)
+
+            assert torch.isnan(scales).sum() == 0
+            assert torch.isnan(input_q).sum() == 0
+
+            return input_q, scales
+
+        def forward(self, input):
+            #return self.forward_calibration(input)
+            dim = input.dim()
+            if dim > 2:
+                input = input.squeeze(0)
+            if self.weight_quant is None:
+                weight_smooth = self.weight * self.scales
+                self.scales_rcp = 1.0 / self.scales
+                self.weight_quant, self.weight_scale = per_token_quant_int8(weight_smooth)
+                del self.weight
+            input_quant, input_scale = per_token_quant_int8_smooth(input, self.scales_rcp)
+            output_tensor = self.blaslt_scaled_mm(input_quant, self.weight_quant, input_scale, self.weight_scale, torch.bfloat16, self.bias)
+            if dim > 2:
+                output_tensor = output_tensor.unsqueeze(0)
+            return output_tensor
+
+        def forward_calibration(self, input):
+            dim = input.dim()
+            if dim > 2:
+                input = input.squeeze(0)
+            if self.count < 48:
+                self.calibration(input)
+            output_tensor = torch.mm(input, self.weight.to(torch.bfloat16).t())
+            if self.bias is not None:
+                output_tensor += self.bias.to(torch.bfloat16)
+            if dim > 2:
+                output_tensor = output_tensor.unsqueeze(0)
+            return output_tensor
+
+        def calibration(self, input):
+            self.count += 1
+            if self.count == 1:
+                self.weight_max = torch.max(self.weight.to(torch.bfloat16), dim=0)[0].clamp(min=1e-5).cpu()
+            if self.count <= 48:
+                tensor = input.abs()
+                comming_max = torch.max(tensor, dim=0)[0].cpu()
+                if self.act_scales is not None:
+                    self.act_scales = torch.max(self.act_scales, comming_max)
+                else:
+                    self.act_scales = comming_max
+            if self.count == 48:
+                print(f"====================================={self.count}==========================================")
+                print(f"weight dtype: {self.weight.dtype} bias : {self.bias.dtype}")
+                # print("act_max: ", self.act_scales)
+                # print("weight_max: ", self.weight_max)
+                self.scales.data = (torch.pow(self.act_scales, self.alpha) / torch.pow(self.weight_max, 1 - self.alpha)).clamp(min=1e-5).cuda()
+                # print("pow(|act_max|, alpha) / pow(|weight_max|, 1-alpha): ", self.scales)
+                # print(f"scales min: {self.scales.min().item()}, max: {self.scales.max().item()}")
+                # print(f"scales has NaN: {torch.any(torch.isnan(self.scales))}")
+                # print(f"scales has INF: {torch.any(torch.isinf(self.scales))}")
+                # print(f"scales has zero: {torch.any(self.scales == 0)}")

def fp8_linear(self, input):
    dtype = self.weight.dtype
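
For reference, the forward path added above (smooth the activations with `scales_rcp`, per-token INT8 quantization via the `_per_token_quant_int8` Triton kernel, per-output-channel INT8 weights, then a dequantized W8A8 GEMM through `quant_ops.hipblaslt_w8a8_gemm`) computes roughly the following. This is a minimal plain-PyTorch sketch of the math only; the helper name `reference_smooth_int8_linear` is illustrative and not part of the diff, and the real kernel accumulates in INT32 on the GPU.

```python
import torch

def reference_smooth_int8_linear(x, weight, smooth_scales, bias=None):
    # SmoothQuant-style rebalancing: divide activations and multiply weights
    # by the same per-input-channel factor, so x_s @ w_s.t() == x @ weight.t().
    x_s = x * (1.0 / smooth_scales)
    w_s = weight * smooth_scales

    # Per-token (per-row) symmetric INT8 quantization of the activations,
    # mirroring _per_token_quant_int8 (absmax / 127 scale, round to nearest).
    x_amax = x_s.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    x_scale = x_amax / 127.0
    x_q = torch.round(x_s / x_scale).clamp(-127, 127).to(torch.int8)

    # Per-output-channel symmetric INT8 quantization of the smoothed weight.
    w_amax = w_s.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    w_scale = w_amax / 127.0
    w_q = torch.round(w_s / w_scale).clamp(-127, 127).to(torch.int8)

    # W8A8 GEMM followed by dequantization with the outer product of the scales.
    # hipblaslt_w8a8_gemm accumulates in INT32; float32 is used here only to keep
    # the reference simple and hardware-agnostic.
    y = x_q.to(torch.float32) @ w_q.to(torch.float32).t()
    y = y * (x_scale * w_scale.t())
    if bias is not None:
        y = y + bias
    return y.to(torch.bfloat16)
```

With `smooth_scales` playing the role of the calibrated `self.scales`, the output should track `x @ weight.t() + bias` up to INT8 rounding error, which is what the alpha-weighted activation/weight statistics gathered in `calibration()` are tuned for.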
@@ -636,9 +707,12 @@ if CUBLAS_IS_AVAILABLE:
        def forward(self, *args, **kwargs):
            return super().forward(*args, **kwargs)

-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, int8_optimizations=None):
-    if int8_optimizations is not None and int8_optimizations:
-        return manual_cast_int8_per_channel
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
+    if model_config is not None and model_config.optimizations.get("int8", False):
+        if model_config.unet_config.get("image_model", "") == "qwen_image":
+            return manual_cast_int8_smooth
+        return manual_cast_int8
    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
    if scaled_fp8 is not None:
        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
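
A quick sanity sketch of the new routing in `pick_operations`: with int8 enabled in `model_config.optimizations`, Qwen-Image models get the smooth-quant ops and any other image model falls back to the renamed per-channel class. The `SimpleNamespace` objects below are hypothetical stand-ins that only carry the two attributes the function reads (`optimizations`, `unet_config`), not real ComfyUI model configs.

```python
import torch
from types import SimpleNamespace

import comfy.ops

# Hypothetical configs exposing just the fields pick_operations inspects.
qwen_cfg = SimpleNamespace(optimizations={"int8": True},
                           unet_config={"image_model": "qwen_image"})
other_cfg = SimpleNamespace(optimizations={"int8": True},
                            unet_config={"image_model": "some_other_model"})

# Expected routing on this branch:
assert comfy.ops.pick_operations(None, torch.bfloat16, model_config=qwen_cfg) is comfy.ops.manual_cast_int8_smooth
assert comfy.ops.pick_operations(None, torch.bfloat16, model_config=other_cfg) is comfy.ops.manual_cast_int8
```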
......
-/root/models/
+/home/models
\ No newline at end of file
@@ -912,8 +912,9 @@ class UNETLoader:
        if weight_dtype == "fp8_e4m3fn":
            model_options["dtype"] = torch.float8_e4m3fn
        elif weight_dtype == "fp8_e4m3fn_fast":
-            print("##### PANN_DEBUG UNETLoader fp8_e4m3fn_fast ####")
            model_options["dtype"] = torch.float8_e4m3fn
+            if unet_name == "Qwen-Image-Edit-2509-smooth-int8.safetensors":
+                model_options["dtype"] = torch.bfloat16
            #model_options["fp8_optimizations"] = True
            model_options["int8_optimizations"] = True
        elif weight_dtype == "fp8_e5m2":
@@ -922,7 +923,6 @@ class UNETLoader:
        unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
        model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
        #model.model = model.model.to(memory_format=torch.channels_last)
-        #print(model.model)
        return (model,)

class CLIPLoader:
......
@@ -462,9 +462,11 @@ def test6(server_url: str):
    image_mapping = {}
    for weight_dtype in ['default', 'fp8_e4m3fn', 'fp8_e4m3fn_fast']:
    #for weight_dtype in ['fp8_e4m3fn_fast']:
        logger.info(f'\n========> {workflow_name} {weight_dtype} <========')
+        if weight_dtype == "fp8_e4m3fn_fast":
+            api_prompt["236"]["inputs"]["unet_name"] = "Qwen-Image-Edit-2509-smooth-int8.safetensors"
        recorder = TimingRecorder()
        for idx, (image1, image2) in enumerate(test_cases):
            api_prompt["247"]["inputs"]["image"] = image1
            api_prompt["248"]["inputs"]["image"] = image2
@@ -613,7 +615,7 @@ if __name__ == "__main__":
    #test6(server_url)

    # Test old photo restoration workflow
-    #test7(server_url)
+    test7(server_url)
    #test7(server_url)
    #test7(server_url)
......