"tests/L0/vscode:/vscode.git/clone" did not exist on "cc5f83b5cda1fd17bf828097e993e47d63a55a4b"
Commit 6c18f54c authored by helloyongyang's avatar helloyongyang
Browse files

Fix a bug and add a quantization kernel

parent 3aa95081
...@@ -4,7 +4,7 @@ except ImportError: ...@@ -4,7 +4,7 @@ except ImportError:
flash_attn_varlen_func = None flash_attn_varlen_func = None
def flash_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None): def flash_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
x = flash_attn_varlen_func( x = flash_attn_varlen_func(
q, q,
k, k,
......
...@@ -4,7 +4,7 @@ except ImportError: ...@@ -4,7 +4,7 @@ except ImportError:
flash_attn_varlen_func_v3 = None flash_attn_varlen_func_v3 = None
def flash_attn3(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None): def flash_attn3(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
x = flash_attn_varlen_func_v3( x = flash_attn_varlen_func_v3(
q, q,
k, k,
......
...@@ -371,6 +371,29 @@ class MMWeightWfp8channelAfp8channeldynamicVllmActSgl(MMWeightQuantTemplate): ...@@ -371,6 +371,29 @@ class MMWeightWfp8channelAfp8channeldynamicVllmActSgl(MMWeightQuantTemplate):
return output_tensor return output_tensor
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWfp8channelAfp8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm
    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

    def __init__(self, weight_name, bias_name):
        # Wire up the template hooks: fp8 per-channel weight loading,
        # vllm-backed dynamic per-channel activation quantization.
        super().__init__(weight_name, bias_name)
        self.weight_need_transpose = True
        self.load_func = self.load_fp8_perchannel_sym
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        # Dynamically quantize the activation, then run the fp8 scaled GEMM
        # with a bf16 output (bias fused into the kernel call).
        quantized_act, act_scale = self.act_quant_func(input_tensor)
        return sgl_kernel.fp8_scaled_mm(
            quantized_act,
            self.weight,
            act_scale,
            self.weight_scale,
            torch.bfloat16,
            bias=self.bias,
        )
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl") @MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate): class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
""" """
...@@ -395,7 +418,7 @@ class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate): ...@@ -395,7 +418,7 @@ class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm") @MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWint8channelAint8channeldynamicActVllm(MMWeightQuantTemplate): class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
""" """
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment