Commit 3aa95081 authored by helloyongyang

[Feature]: support multiple quant kernels

parent 9a686a73
import torch
from abc import ABCMeta, abstractmethod
from vllm import _custom_ops as ops
import sgl_kernel
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
@@ -9,6 +10,11 @@ try:
except ImportError:
Q8F = None
try:
import deep_gemm
except ImportError:
deep_gemm = None
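# Editorial note: Q8F and deep_gemm are optional soft dependencies (guarded
# imports above); registry entries that rely on a missing one only fail when the
# corresponding class is loaded or applied, not at import time.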
class MMWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, bias_name):
@@ -70,8 +76,102 @@ class MMWeightForceFP32(MMWeight):
self.bias = self.bias.to(torch.float32)
class MMWeightQuantTemplate(MMWeightTemplate):
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
self.load_func = None
self.weight_need_transpose = True
self.act_quant_func = None
"""
weight load functions
"""
def load(self, weight_dict):
self.load_func(weight_dict)
if self.weight_need_transpose:
self.weight = self.weight.t()
def load_quantized(self, weight_dict):
self.weight = weight_dict[self.weight_name].cuda()
        # Bug fix (editorial): str.rstrip strips a trailing *character set*, not a suffix --
        # "model.wte.weight".rstrip(".weight") yields "model", not "model.wte".
        # removesuffix (Python 3.9+) strips exactly the ".weight" suffix.
        self.weight_scale = weight_dict[self.weight_name.removesuffix(".weight") + ".weight_scale"].cuda()
def load_fp8_perchannel_sym(self, weight_dict):
if self.config.get("weight_auto_quant", True):
self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
self.weight = self.weight.to(torch.float8_e4m3fn)
self.weight_scale = self.weight_scale.to(torch.float32)
else:
self.load_quantized(weight_dict)
self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
def load_int8_perchannel_sym(self, weight_dict):
if self.config.get("weight_auto_quant", True):
            # Editorial: .cuda() added to match load_fp8_perchannel_sym; the MM
            # kernels below expect the quantized weight on GPU.
            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
w_quantizer = IntegerQuantizer(8, True, "per_channel")
self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
self.weight = self.weight.to(torch.int8)
self.weight_scale = self.weight_scale.to(torch.float32)
else:
self.load_quantized(weight_dict)
self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
def load_fp8_perblock128_sym(self, weight_dict):
if self.config.get("weight_auto_quant", True):
self.weight = weight_dict[self.weight_name].cuda()
self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
else:
self.load_quantized(weight_dict)
self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
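    # Shape note (editorial): for a weight of shape (out_features, in_features), the
    # per-channel loaders above yield one scale per output channel (shape
    # (out_features, 1) or (out_features,), depending on the quantizer), while
    # load_fp8_perblock128_sym yields (ceil(out/128), ceil(in/128)) block scales
    # via per_block_cast_to_fp8 below.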
def per_block_cast_to_fp8(self, x):
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros((deep_gemm.ceil_div(m, 128) * 128, deep_gemm.ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
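    # Worked example (editorial sketch): for x of shape (4096, 2048) -- already
    # multiples of 128 -- x_view has shape (32, 128, 16, 128), x_amax has shape
    # (32, 1, 16, 1), and the returned scale (x_amax / 448.0) reshapes to (32, 16):
    # one fp32 scale per 128x128 block. 448.0 is the largest finite e4m3 value, so
    # each block is stretched to use the full fp8 range.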
"""
act quant kernels
"""
def act_quant_fp8_perchannel_sym_vllm(self, x):
input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
return input_tensor_quant, input_tensor_scale
def act_quant_fp8_perchannel_sym_sgl(self, x):
m, k = x.shape
input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False)
sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
return input_tensor_quant, input_tensor_scale
def act_quant_int8_perchannel_sym_vllm(self, x):
input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
return input_tensor_quant, input_tensor_scale
def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x):
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
def act_quant_fp8_perchannelgroup128_sym_sgl(self, x):
m, k = x.shape
input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
input_tensor_scale = torch.empty((m, k // 128), dtype=torch.float32, device="cuda", requires_grad=False)
sgl_kernel.sgl_per_token_group_quant_fp8(x, input_tensor_quant, input_tensor_scale, group_size=128, eps=1e-10, fp8_min=-448.0, fp8_max=448.0)
return input_tensor_quant, input_tensor_scale
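    # Scale-shape summary (editorial): for an activation x of shape (m, k), the
    # per-channel (per-token) kernels above return scales of shape (m, 1), while
    # the group-128 kernels return (m, k // 128). All fp8 paths are symmetric
    # e4m3 with the per-token/per-group amax mapped to 448.0.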
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm")
-class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
+class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm
@@ -83,31 +183,23 @@ class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
-    def load(self, weight_dict):
-        if self.config.get("weight_auto_quant", True):
-            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
-            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
-            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
-            self.weight = self.weight.to(torch.float8_e4m3fn).t().cuda()
-            self.weight_scale = self.weight_scale.to(torch.float32).cuda()
-        else:
-            self.weight = weight_dict[self.weight_name].t().cuda()
-            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + ".weight_scale"].cuda()
-        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
+        self.load_func = self.load_fp8_perchannel_sym
+        self.weight_need_transpose = True
+        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
-        qinput, x_scale = ops.scaled_fp8_quant(input_tensor, None, scale_ub=None, use_per_token_if_dynamic=True)
-        torch.ops._C.cutlass_scaled_mm(output_tensor, qinput, self.weight, x_scale, self.weight_scale, self.bias)
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        torch.ops._C.cutlass_scaled_mm(output_tensor, input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, self.bias)
return output_tensor
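# A minimal usage sketch (editorial; the dict-style MM_WEIGHT_REGISTER lookup and
# the externally attached `config` are assumptions about surrounding lightx2v
# code, not verified here):
#
#     mm = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("proj.weight", "proj.bias")
#     mm.config = {"weight_auto_quant": True}  # hypothetical: normally set by the model loader
#     mm.load(weight_dict)   # quantizes the weight per-channel and transposes it to (k, n)
#     y = mm.apply(x)        # x: (tokens, k) bf16 -> y: (tokens, n) bf16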
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm")
-class MMWeightWint8channelAint8channeldynamicVllm(MMWeightTemplate):
+class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm
@@ -119,31 +211,46 @@ class MMWeightWint8channelAint8channeldynamicVllm(MMWeightTemplate):
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
-    def load(self, weight_dict):
-        if self.config.get("weight_auto_quant", True):
-            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
-            w_quantizer = IntegerQuantizer(8, True, "per_channel")
-            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
-            self.weight = self.weight.to(torch.int8).t().cuda()
-            self.weight_scale = self.weight_scale.to(torch.float32).cuda()
-        else:
-            self.weight = weight_dict[self.weight_name].t().cuda()
-            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + ".weight_scale"].cuda()
-        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
+        self.load_func = self.load_int8_perchannel_sym
+        self.weight_need_transpose = True
+        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
-        qinput, x_scale, _ = ops.scaled_int8_quant(input_tensor, scale=None, azp=None, symmetric=True)
-        torch.ops._C.cutlass_scaled_mm(output_tensor, qinput, self.weight, x_scale, self.weight_scale, self.bias)
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        torch.ops._C.cutlass_scaled_mm(output_tensor, input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, self.bias)
return output_tensor
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
Kernel: Q8F
"""
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
self.load_func = self.load_fp8_perchannel_sym
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = Q8F.linear.fp8_linear(input_tensor_quant, self.weight, self.bias, input_tensor_scale, self.weight_scale, out_dtype=torch.bfloat16)
return output_tensor.squeeze(0)
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F")
-class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightTemplate):
+class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F
@@ -155,55 +262,164 @@ class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightTemplate):
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
+        self.load_func = self.load_int8_perchannel_sym
+        self.weight_need_transpose = False
+        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
-    def load(self, weight_dict):
-        if self.config.get("weight_auto_quant", True):
-            self.weight = weight_dict[self.weight_name].cuda()
-            w_quantizer = IntegerQuantizer(8, True, "per_channel")
-            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
-            self.weight = self.weight.to(torch.int8)
-            self.weight_scale = self.weight_scale.to(torch.float32)
-        else:
-            self.weight = weight_dict[self.weight_name].cuda()
-            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + ".weight_scale"].cuda()
-        self.bias = weight_dict[self.bias_name].float().cuda() if self.bias_name is not None else None
-    def apply(self, input_tensor, act=None):
-        qinput, x_scale, _ = ops.scaled_int8_quant(input_tensor, scale=None, azp=None, symmetric=True)
-        output_tensor = Q8F.linear.q8_linear(qinput, self.weight, self.bias, x_scale, self.weight_scale, fuse_gelu=False, out_dtype=torch.bfloat16)
+    def apply(self, input_tensor):
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        output_tensor = Q8F.linear.q8_linear(input_tensor_quant, self.weight, self.bias, input_tensor_scale, self.weight_scale, fuse_gelu=False, out_dtype=torch.bfloat16)
return output_tensor.squeeze(0)
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightTemplate):
@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemm(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F
Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm
Quant MM:
Weight: fp8 perblock 128x128 sym
Act: fp8 perchannel-pergroup group=128 dynamic sym
Kernel: Deepgemm
Reference: https://github.com/deepseek-ai/DeepGEMM
Example:
Act(1024, 2048) x Weight(2048, 4096) = Out(1024, 4096)
Act : torch.Size([1024, 2048]), torch.float8_e4m3fn
Act Scale: torch.Size([1024, 16]), torch.float32
Weight : torch.Size([4096, 2048]), torch.float8_e4m3fn
Weight Scale: torch.Size([32, 16]), torch.float32
Out : torch.Size([1024, 4096]), torch.bfloat16
"""
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
self.load_func = self.load_fp8_perblock128_sym
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_deepgemm
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[0])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
deep_gemm.gemm_fp8_fp8_bf16_nt((input_tensor_quant, input_tensor_scale), (self.weight, self.weight_scale), output_tensor)
if self.bias is not None:
output_tensor.add_(self.bias)
return output_tensor
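# Layout note (editorial): gemm_fp8_fp8_bf16_nt consumes the activation as (m, k)
# and the weight as (n, k) ("nt" = rhs transposed), matching the docstring example
# above; hence weight_need_transpose = False here, while the cutlass_scaled_mm
# paths store the weight transposed to (k, n) and set weight_need_transpose = True.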
@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
"""
Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl
Quant MM:
Weight: fp8 perblock 128x128 sym
Act: fp8 pertoken-pergroup group=128 dynamic sym
Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
"""
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
self.load_func = self.load_fp8_perblock128_sym
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[0])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
deep_gemm.gemm_fp8_fp8_bf16_nt((input_tensor_quant, input_tensor_scale), (self.weight, self.weight_scale), output_tensor)
if self.bias is not None:
output_tensor.add_(self.bias)
return output_tensor
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl")
class MMWeightWfp8channelAfp8channeldynamicVllmActSgl(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
-        Kernel: Q8F
+        Kernel: quant-mm using vllm, act dynamic quant using Sgl-kernel
"""
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
+        self.load_func = self.load_fp8_perchannel_sym
+        self.weight_need_transpose = True
+        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
-    def load(self, weight_dict):
-        if self.config.get("weight_auto_quant", True):
-            self.weight = weight_dict[self.weight_name].cuda()
-            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
-            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
-            self.weight = self.weight.to(torch.float8_e4m3fn)
-            self.weight_scale = self.weight_scale.to(torch.float32)
-        else:
-            self.weight = weight_dict[self.weight_name].cuda()
-            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + ".weight_scale"].cuda()
-        self.bias = weight_dict[self.bias_name].float().cuda() if self.bias_name is not None else None
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
torch.ops._C.cutlass_scaled_mm(output_tensor, input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, self.bias)
return output_tensor
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
Kernel: Sgl-kernel
"""
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
self.load_func = self.load_fp8_perchannel_sym
self.weight_need_transpose = True
self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
def apply(self, input_tensor):
-        qinput, x_scale = ops.scaled_fp8_quant(input_tensor, None, scale_ub=None, use_per_token_if_dynamic=True)
-        output_tensor = Q8F.linear.fp8_linear(qinput, self.weight, self.bias, x_scale, self.weight_scale, out_dtype=torch.bfloat16)
-        return output_tensor.squeeze(0)
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        output_tensor = sgl_kernel.fp8_scaled_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, torch.bfloat16, bias=self.bias)
+        return output_tensor
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
"""
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
self.load_func = self.load_int8_perchannel_sym
self.weight_need_transpose = True
self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # int8_scaled_mm allocates and returns its own output, so no torch.empty
        # preallocation is needed here (editorial: a dead preallocation was removed)
        output_tensor = sgl_kernel.int8_scaled_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, torch.bfloat16, self.bias)
        return output_tensor
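# Editorial smoke-test sketch (not part of the original commit): compares one
# quantized path against a plain matmul reference. The dict-style
# MM_WEIGHT_REGISTER lookup and the externally attached `config` are assumptions
# about surrounding lightx2v code; the error tolerance depends on the quantizer.
def _editorial_smoke_test(key="W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", m=256, k=1024, n=2048):
    w = torch.randn(n, k, dtype=torch.float32)
    b = torch.randn(n, dtype=torch.float32)
    mm = MM_WEIGHT_REGISTER[key]("proj.weight", "proj.bias")  # assumed registry API
    mm.config = {"weight_auto_quant": True}  # hypothetical: normally set by the model loader
    mm.load({"proj.weight": w, "proj.bias": b})
    x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
    y = mm.apply(x)  # quantized int8 GEMM, output in bf16
    y_ref = x.float() @ w.t().cuda() + b.cuda()  # fp32 reference
    print("mean abs err:", (y.float() - y_ref).abs().mean().item())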
if __name__ == "__main__":
......