Commit bac201c9 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix: 修复ep的变量未定义

set VLLM_USE_FUSED_QA_KVA_GEMM=1
feat:w4a8Linear调用apply_int8_linear,以支持blaslt
parent ffd123f6
...@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
if (major, minor) >= ('2', '5'): if (major, minor) >= ('2', '5'):
version = 'das.opt5.' + sha[:7] version = 'das.opt6.' + sha[:7]
else: else:
if (major, minor) >= ('2', '5'): if (major, minor) >= ('2', '5'):
version = 'das.opt5' version = 'das.opt6'
# dtk version # dtk version
......
...@@ -1365,7 +1365,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1365,7 +1365,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Only quantized DeepSeek models supported. # Only quantized DeepSeek models supported.
# Unquantized versions are not supported. # Unquantized versions are not supported.
"VLLM_USE_FUSED_QA_KVA_GEMM": "VLLM_USE_FUSED_QA_KVA_GEMM":
lambda: (os.environ.get("VLLM_USE_FUSED_QA_KVA_GEMM", "False").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_QA_KVA_GEMM", "True").lower() in
("true", "1")), ("true", "1")),
"VLLM_ZERO_OVERHEAD_ENHANCE": "VLLM_ZERO_OVERHEAD_ENHANCE":
lambda: (os.getenv('VLLM_ZERO_OVERHEAD_ENHANCE', '0').lower() in lambda: (os.getenv('VLLM_ZERO_OVERHEAD_ENHANCE', '0').lower() in
......
...@@ -5,6 +5,7 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -5,6 +5,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase) from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import apply_int8_linear
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
...@@ -111,7 +112,9 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -111,7 +112,9 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
for key, value in configs_dict.items(): for key, value in configs_dict.items():
m=int(key.split('_')[0]) m=int(key.split('_')[0])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value) ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
else: elif self.w8a8_strategy==3:
layer.weight.data = layer.weight.data.T
else:
weight_data=layer.weight.data weight_data=layer.weight.data
_weight=weight_data.T.contiguous().reshape(n,-1) _weight=weight_data.T.contiguous().reshape(n,-1)
layer.weight.data=_weight layer.weight.data=_weight
...@@ -158,81 +161,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -158,81 +161,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
input_quant_args: Optional[list[torch.Tensor]] = None, input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None, **_ silu_quant_args: Optional[list[torch.Tensor]] = None, **_
): ):
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None: return apply_int8_linear(input=x,
assert len(input_quant_args) == 2 weight=layer.weight,
x_q, x_scale = input_quant_args weight_scale=layer.weight_scale,
elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None: bias=bias,
assert len(silu_quant_args) == 2 w8a8_strategy=self.w8a8_strategy,
x_q, x_scale = silu_quant_args input_quant_args=input_quant_args,
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and input_quant_args is not None: silu_quant_args=silu_quant_args)
assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and silu_quant_args is not None:
assert len(silu_quant_args) == 2
x_q, x_scale = silu_quant_args
else:
x_q, x_scale = per_token_quant_int8(x)
if self.w8a8_strategy==1:
m=x_q.shape[0]
k=x_q.shape[1]
n=layer.weight.shape[1]
if len(W8A8_TRITONJSON.triton_json_dict)==0:
best_config=None
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
if m<=16:
m_=m
elif m<=64:
m_ = ((m + 3) // 4) * 4 #取值到最近的4的倍数
elif m<=160:
m_ = (m // 8) * 8
elif m<200: #256
m_=160
elif m<480: #512
m_=256
elif m<960: #1024
m_=512
elif m<2048:
m_=1024
elif m<4096:
m_=2048
elif m<6000:
m_=4096
else:
m_=8192
best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
else:
best_config=None
#if best_config==None:
# print("m:{},n:{},k:{}".format(m,n,k))
# print("config not found!")
return ops.triton_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias,best_config=best_config)
elif self.w8a8_strategy==2:
return ops.cutlass_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias)
else:
return ops.rocblas_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias)
class SlimQuantW4A8Int8MoEMethod: class SlimQuantW4A8Int8MoEMethod:
......
...@@ -288,6 +288,9 @@ def get_model_architecture( ...@@ -288,6 +288,9 @@ def get_model_architecture(
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1' os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if not envs.is_set("VLLM_USE_FUSED_RMS_ROPE"): if not envs.is_set("VLLM_USE_FUSED_RMS_ROPE"):
os.environ['VLLM_USE_FUSED_RMS_ROPE'] = '1' os.environ['VLLM_USE_FUSED_RMS_ROPE'] = '1'
if architectures in [['Qwen3ForCausalLM']]:
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if architectures in [['DeepseekV32ForCausalLM']]: if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"): if not envs.is_set("VLLM_USE_V32_ENCODE"):
...@@ -336,6 +339,9 @@ def get_model_architecture( ...@@ -336,6 +339,9 @@ def get_model_architecture(
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1' os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if not envs.is_set("VLLM_USE_FUSED_RMS_ROPE"): if not envs.is_set("VLLM_USE_FUSED_RMS_ROPE"):
os.environ['VLLM_USE_FUSED_RMS_ROPE'] = '1' os.environ['VLLM_USE_FUSED_RMS_ROPE'] = '1'
if architectures in [['Qwen3ForCausalLM']]:
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if architectures in [['DeepseekV32ForCausalLM']]: if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"): if not envs.is_set("VLLM_USE_V32_ENCODE"):
......
...@@ -422,9 +422,6 @@ class DeepseekV2MoE(nn.Module): ...@@ -422,9 +422,6 @@ class DeepseekV2MoE(nn.Module):
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
# fp16 mode not fused quant # fp16 mode not fused quant
if i_q is not None:
i_q=iqis[0]
i_s=iqis[1]
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
i_q=i_q, i_s=i_s) i_q=i_q, i_s=i_s)
...@@ -469,9 +466,8 @@ class DeepseekV2MoE(nn.Module): ...@@ -469,9 +466,8 @@ class DeepseekV2MoE(nn.Module):
assert shared_output is not None assert shared_output is not None
final_hidden_states += (shared_output * (1. / self.routed_scaling_factor)) final_hidden_states += (shared_output * (1. / self.routed_scaling_factor))
else: else:
if i_q is not None: if iqis is not None:
i_q=iqis[0] i_q, i_s = iqis
i_s=iqis[1]
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
i_q=i_q, i_s=i_s) i_q=i_q, i_s=i_s)
......
...@@ -577,7 +577,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -577,7 +577,7 @@ class FlashAttentionImpl(AttentionImpl):
layer._v_scale, layer._v_scale,
) )
else: else:
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE and key.dtype == value.dtype == torch.float16: if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE:
from lightop import reshape_and_cache_cuda from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda( reshape_and_cache_cuda(
key, value, key, value,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment