Commit 6ca1362b authored by wujl5's avatar wujl5 Committed by wangmin6
Browse files

perf: DS v2增加DTBMM融合,默认关闭

parent 3824b261
......@@ -304,6 +304,7 @@ if TYPE_CHECKING:
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE: bool = False
VLLM_USE_FUSED_DTBMM: bool = False # DOUBLE TRANS BMM FP8
def get_default_cache_root():
......@@ -1910,6 +1911,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE":
lambda: (os.environ.get("VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE", "False").lower() in
("true", "1")),
# DOUBLE TRANSPOSE BMM FP8 format use in NMZ DeepSeek models
"VLLM_USE_FUSED_DTBMM":
lambda: (os.environ.get("VLLM_USE_FUSED_DTBMM", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -1339,9 +1339,18 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)
if self.enable_fused_DTBmm():
weight_uv_NLV, weight_uv_scale_NL =self.weight_quant_fp8(self.W_UV, 1)
self.weight_uv_bmm = weight_uv_NLV.transpose(1,2).contiguous()
self.weight_uv_scale_bmm = weight_uv_scale_NL.transpose(1,2).contiguous()
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
batch_size = x.shape[:-2].numel()
if self.enable_fused_DTBmm() and batch_size <= 32:
pass
else:
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
out = out.view(-1, self.num_heads, self.v_head_dim)
if self.is_aiter_triton_fp4_bmm_enabled:
out = rocm_aiter_ops.batched_gemm_a16wfp4(
......@@ -1360,19 +1369,66 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
)
else:
# Convert from (B, N * V) to (N, B, V)
out = out.transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
# Convert from (N, B, V) to (B, N * V)
out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
# Adjust output buffer shape back to the original (B, N * V)
N, B, V = out.shape
out.resize_((B, N * V))
out.copy_(out_new) # Copy result
# DOUBLE TRANS BMM FP8
if self.enable_fused_DTBmm() and batch_size <= 32:
out = out.transpose(0, 1)
x = x.view(-1, self.num_heads, self.kv_lora_rank).contiguous()
B, N, L = x.shape
N, V, L = self.weight_uv_bmm.shape
from lmslim.layers.gemm.fp8_utils import per_token_quant_fp8
from lightop import fused_bmm as fused_DTBmm
from lightop import get_batched_gemm_w8a8_config as DTBmm_config
x = x.reshape(-1, self.num_heads, self.kv_lora_rank).contiguous()
x_q, x_scale = per_token_quant_fp8(x)
x_out = torch.empty(B, N, V, dtype=torch.bfloat16, device=x.device)
_dtype = torch.bfloat16
_config, _status = DTBmm_config(B, N, L)
assert x_q.shape == (B, N, L) , f"assert error {x_q.shape}"
assert x_scale.shape == (B, N, 1) , f"assert error {x_scale.shape}"
fused_DTBmm(x=x_q, w=self.weight_uv_bmm, x_scale=x_scale, w_scale=self.weight_uv_scale_bmm,
bias=None, dtype=_dtype, output=x_out,
transpose_bm=False, transpose_bm_in=False, config=_config)
out_new = x_out.reshape(-1, self.num_heads * self.v_head_dim) # B, N*V
out.resize_((B, N * V))
out.copy_(out_new)
else:
# Convert from (B, N * V) to (N, B, V)
out = out.transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
# Convert from (N, B, V) to (B, N * V)
out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
# Adjust output buffer shape back to the original (B, N * V)
N, B, V = out.shape
out.resize_((B, N * V))
out.copy_(out_new) # Copy result
def enable_fused_DTBmm(self): # DOUBLE TRANS BMM FP8
if envs.VLLM_USE_FUSED_DTBMM and \
torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return True
else:
return False
def weight_quant_fp8(self,
weight,
dim:int=1):
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_min = finfo.min
fp8_max = finfo.max
absmax = torch.max(weight.abs(), dim=dim, keepdim=True).values
absmax = absmax.clamp(min=1e-10)
scale = absmax.to(torch.float32) / fp8_max
scale = scale.clamp(min=1e-10)
weight_fp32 = weight.float() if weight.dtype != torch.float32 else weight
scale_fp32 = scale.float() if scale.dtype != torch.float32 else scale
weight_q = weight_fp32 / scale_fp32
weight_q = weight_q.clamp(fp8_min, fp8_max)
weight_q = weight_q.to(torch.float8_e4m3fn)
return weight_q, scale
class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment