Commit 6c18f54c authored by helloyongyang's avatar helloyongyang
Browse files

fix bug & add quant kernel

parent 3aa95081
......@@ -4,7 +4,7 @@ except ImportError:
flash_attn_varlen_func = None
def flash_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None):
def flash_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
x = flash_attn_varlen_func(
q,
k,
......
......@@ -4,7 +4,7 @@ except ImportError:
flash_attn_varlen_func_v3 = None
def flash_attn3(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None):
def flash_attn3(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
x = flash_attn_varlen_func_v3(
q,
k,
......
......@@ -371,6 +371,29 @@ class MMWeightWfp8channelAfp8channeldynamicVllmActSgl(MMWeightQuantTemplate):
return output_tensor
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWfp8channelAfp8channeldynamicSglActVllm(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
"""
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
self.load_func = self.load_fp8_perchannel_sym
self.weight_need_transpose = True
self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = sgl_kernel.fp8_scaled_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, torch.bfloat16, bias=self.bias)
return output_tensor
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
"""
......@@ -395,7 +418,7 @@ class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWint8channelAint8channeldynamicActVllm(MMWeightQuantTemplate):
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment