Commit a1f5ce6e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-bmm-2cd-0206-rebase-newest' into 'v0.9.2-dev'

perf: mla后面的DTbmm融合

See merge request dcutoolkit/deeplearing/vllm!415
parents 9e59081f 8be144d5
...@@ -218,6 +218,7 @@ if TYPE_CHECKING: ...@@ -218,6 +218,7 @@ if TYPE_CHECKING:
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT: int = -1 VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT: int = -1
VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False
VLLM_USE_MOE_W16A16_TRITON: bool = False VLLM_USE_MOE_W16A16_TRITON: bool = False
VLLM_USE_FUSED_DTBMM: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1398,6 +1399,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1398,6 +1399,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MOE_W16A16_TRITON": "VLLM_USE_MOE_W16A16_TRITON":
lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in
("true", "1")), ("true", "1")),
# Only quantized DeepSeek models supported.
"VLLM_USE_FUSED_DTBMM":
lambda: (os.environ.get("VLLM_USE_FUSED_DTBMM", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -218,6 +218,8 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, ...@@ -218,6 +218,8 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from lightop import fused_rms_norm_rope_contiguous, fuse_rmsnorm_rope_quant_qkv from lightop import fused_rms_norm_rope_contiguous, fuse_rmsnorm_rope_quant_qkv
from lmslim.layers.gemm.fp8_utils import per_token_quant_fp8
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
...@@ -871,13 +873,74 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -871,13 +873,74 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
return attn_out, lse return attn_out, lse
return attn_out return attn_out
def weight_quant_fp8(self, weight, dim:Optional[int]=1):
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_min = finfo.min
fp8_max = finfo.max
absmax = torch.max(weight.abs(), dim=dim, keepdim=True).values
absmax = absmax.clamp(min=1e-10)
scale = absmax.to(torch.float32) / fp8_max
scale = scale.clamp(min=1e-10)
weight_fp32 = weight.float() if weight.dtype != torch.float32 else weight
scale_fp32 = scale.float() if scale.dtype != torch.float32 else scale
weight_q = weight_fp32 / scale_fp32
weight_q = weight_q.clamp(fp8_min, fp8_max)
weight_q = weight_q.to(torch.float8_e4m3fn)
return weight_q, scale
def _v_up_proj(self, x): def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L) if self.enable_fused_DTBmm():
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) x = x.view(-1, self.num_heads, self.kv_lora_rank).contiguous()
# Multiply (N, B, L) x (N, L, V) -> (N, B, V) B, N, L = x.shape
x = torch.bmm(x, self.W_UV) N, V, L = self.weight_uv_bmm.shape
# Convert from (N, B, V) to (B, N * V) if B <= 32:
return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) from lightop import fused_bmm as fused_DTBmm
from lightop import get_batched_gemm_w8a8_config as DTBmm_config
x = x.reshape(-1, self.num_heads, self.kv_lora_rank).contiguous()
x_q, x_scale = per_token_quant_fp8(x)
x_out = torch.empty(B, N, V, dtype=torch.bfloat16, device=x.device)
_dtype = torch.bfloat16
_config, _status = DTBmm_config(B, N, L)
assert x_q.shape == (B, N, L) , f"assert error {x_q.shape}"
assert x_scale.shape == (B, N, 1) , f"assert error {x_scale.shape}"
fused_DTBmm(x=x_q, w=self.weight_uv_bmm, x_scale=x_scale, w_scale=self.weight_uv_scale_bmm,
bias=None, dtype=_dtype, output=x_out,
transpose_bm=False, transpose_bm_in=False, config=_config)
out = x_out.reshape(-1, self.num_heads * self.v_head_dim)
return out
else:
from lmslim import quant_ops
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
x_= x.reshape(-1,self.kv_lora_rank).contiguous()
x_q, x_scale = per_token_quant_fp8(x_)
x_q = x_q.reshape(self.num_heads,-1,self.kv_lora_rank).contiguous()
x_scale = x_scale.reshape(self.num_heads,-1).contiguous()
weight_k = self.W_UV.shape[1]
weight_n = self.W_UV.shape[2]
_, result = quant_ops.hipblaslt_w8a8_channelwise_gemm(
x_q, self.weight_uv_bmm , x_scale, self.weight_uv_scale_bmm,
x.shape[1], weight_n, weight_k, 'NT', torch.bfloat16, None)
return result.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
else: # default
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
def enable_fused_DTBmm(self):
if envs.VLLM_USE_FUSED_DTBMM and \
self.kv_cache_dtype == "fp8_e4m3" and \
torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return True
else:
return False
def process_weights_after_loading(self, act_dtype: torch.dtype): def process_weights_after_loading(self, act_dtype: torch.dtype):
...@@ -932,6 +995,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -932,6 +995,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.W_UV = W_UV.transpose(0, 1) self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L) # Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0) self.W_UK_T = W_UK.permute(1, 2, 0)
if self.enable_fused_DTBmm():
weight_uv_NLV, weight_uv_scale_NL =self.weight_quant_fp8(self.W_UV, 1)
self.weight_uv_bmm = weight_uv_NLV.transpose(1,2).contiguous()
self.weight_uv_scale_bmm = weight_uv_scale_NL.transpose(1,2).contiguous()
def _compute_prefill_context( def _compute_prefill_context(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment