Commit 9dd70f0e authored by zhuwenwen's avatar zhuwenwen
Browse files

add VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA to use fused rmsnorm + contiguous +...

Add VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA to enable the fused rmsnorm + contiguous + rope (for DeepSeek-V3) + concat_and_cache_mla + q-quant kernel, and to control the bmm (TODO) + cat + MLA (fp8) path.
parent 680ee839
...@@ -204,6 +204,8 @@ class Attention(nn.Module): ...@@ -204,6 +204,8 @@ class Attention(nn.Module):
# shape does not match the query shape, so we optionally let the model # shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape. # definition specify the output tensor shape.
output_shape: Optional[torch.Size] = None, output_shape: Optional[torch.Size] = None,
query_nope: Optional[torch.Size] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None, q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None, key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None,
...@@ -270,7 +272,7 @@ class Attention(nn.Module): ...@@ -270,7 +272,7 @@ class Attention(nn.Module):
query, key, value, output, self.layer_name) query, key, value, output, self.layer_name)
else: else:
torch.ops.vllm.unified_attention_with_output( torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name, None, q_ori, key_normed, positions, weight, cos_sin_cache) query, key, value, output, self.layer_name, None, query_nope, num_local_heads, q_ori, key_normed, positions, weight, cos_sin_cache)
return output.view(-1, hidden_size) return output.view(-1, hidden_size)
else: else:
if self.use_direct_call: if self.use_direct_call:
...@@ -511,6 +513,8 @@ def unified_attention_with_output( ...@@ -511,6 +513,8 @@ def unified_attention_with_output(
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None, q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None, key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None,
...@@ -542,6 +546,8 @@ def unified_attention_with_output( ...@@ -542,6 +546,8 @@ def unified_attention_with_output(
attn_metadata, attn_metadata,
output=output, output=output,
output_scale=output_scale, output_scale=output_scale,
query_nope=query_nope,
num_local_heads=num_local_heads,
q_ori=q_ori, q_ori=q_ori,
key_normed=key_normed, key_normed=key_normed,
positions=positions, positions=positions,
......
...@@ -277,6 +277,60 @@ def flash_mla_with_kvcache_fp8( ...@@ -277,6 +277,60 @@ def flash_mla_with_kvcache_fp8(
) )
return out, softmax_lse return out, softmax_lse
def flash_mla_with_kvcache_fp8_with_cat(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the fp8 MLA KV-cache kernel with the query passed as two parts.

    The query is handed to ``fwd_kvcache_mla_fp8_with_cat`` as separate
    ``q_nope`` and ``q_pe`` tensors rather than as one pre-concatenated
    tensor.

    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, 512).
        q_pe: (batch_size, seq_len_q, num_heads_q, 64).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by
            get_mla_metadata.
        softmax_scale: Scale of QK^T before softmax. When None, defaults to
            1 / sqrt(head_dim), where head_dim is the sum of the last
            dimensions of ``q_nope`` and ``q_pe``.
        causal: Whether to apply a causal attention mask.
        descale_q: (batch_size), torch.float32. Descaling factors for Q,
            used for fp8 quantization.
        descale_k: (batch_size), torch.float32. Descaling factors for K,
            used for fp8 quantization.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        # Full query head dim is the nope part plus the rope part.
        qk_head_dim = q_nope.shape[-1] + q_pe.shape[-1]
        softmax_scale = qk_head_dim ** (-0.5)

    # The kernel takes positional arguments in its own order: note that
    # cache_seqlens precedes block_table, unlike this wrapper's signature.
    kernel_args = (
        q_nope,
        q_pe,
        k_cache,
        None,  # NOTE(review): presumably an unused separate V-cache slot - confirm
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k,
    )
    attn_out, lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat(*kernel_args)
    return attn_out, lse
# #
# TODO: Add fake functions # TODO: Add fake functions
# #
......
...@@ -198,6 +198,7 @@ if TYPE_CHECKING: ...@@ -198,6 +198,7 @@ if TYPE_CHECKING:
VLLM_PP_DEBUG: bool = False VLLM_PP_DEBUG: bool = False
VLLM_USE_V32_ENCODE: bool = False VLLM_USE_V32_ENCODE: bool = False
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
...@@ -1301,6 +1302,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1301,6 +1302,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv('VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT', 'False').lower() in lambda: (os.getenv('VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT', 'False').lower() in
("true", "1")), ("true", "1")),
# vLLM will use the fused rmsnorm + contiguous + rope (for DeepSeek-V3) + concat_and_cache_mla + q-quant kernel, and control the bmm + cat + MLA (fp8) path
"VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA":
lambda: (os.getenv('VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA', 'False').lower() in
("true", "1")),
# vLLM will use fused RMS + RoPE kernel # vLLM will use fused RMS + RoPE kernel
"VLLM_USE_FUSED_RMS_ROPE": "VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "True").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "True").lower() in
......
...@@ -260,6 +260,8 @@ def get_model_architecture( ...@@ -260,6 +260,8 @@ def get_model_architecture(
os.environ['VLLM_USE_CAT_MLA'] = '1' os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"): if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1' os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
# if not envs.is_set("VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA"):
# os.environ['VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"): if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1' os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}: if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
...@@ -303,6 +305,8 @@ def get_model_architecture( ...@@ -303,6 +305,8 @@ def get_model_architecture(
os.environ['VLLM_USE_CAT_MLA'] = '1' os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"): if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1' os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
# if not envs.is_set("VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA"):
# os.environ['VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"): if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1' os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}: if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
......
...@@ -838,6 +838,8 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -838,6 +838,8 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim), self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q, q_ori=q,
key_normed=kv_c_normed, key_normed=kv_c_normed,
positions=positions, positions=positions,
...@@ -886,6 +888,8 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -886,6 +888,8 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim), self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q, q_ori=q,
key_normed=kv_c_normed, key_normed=kv_c_normed,
positions=positions, positions=positions,
...@@ -945,6 +949,8 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -945,6 +949,8 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim), self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q, q_ori=q,
key_normed=kv_c_normed, key_normed=kv_c_normed,
positions=positions, positions=positions,
......
...@@ -217,7 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, ...@@ -217,7 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from lightop import fused_rms_norm_rope_contiguous from lightop import fused_rms_norm_rope_contiguous, fuse_rmsnorm_rope_quant_qkv
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
...@@ -1233,21 +1233,62 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1233,21 +1233,62 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str = "bf16" kv_cache_dtype_str = "bf16"
else: else:
kv_cache_dtype_str = self.kv_cache_dtype kv_cache_dtype_str = self.kv_cache_dtype
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...], if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
q, if has_prefill:
k_pe.squeeze(1), fused_rms_norm_rope_contiguous(
k_c_normed, # not normed positions[:num_actual_toks, ...],
key_normed[:num_actual_toks, ...], # normed q,
weight, k_pe.squeeze(1),
cos_sin_cache, k_c_normed, # not normed
attn_metadata.slot_mapping.flatten(), key_normed[:num_actual_toks, ...], # normed
kv_cache, weight,
kv_cache_dtype_str, cos_sin_cache,
1.0, attn_metadata.slot_mapping.flatten(),
False, kv_cache,
1e-6, kv_cache_dtype_str,
) 1.0,
False,
1e-6,
)
else:
q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device)
q_quant = torch.empty_like(q_tensor, dtype=kv_cache_dtype_str, device=q.device)
q_scale = torch.empty(q.shape[0], dtype=torch.float32, device=q.device)
fuse_rmsnorm_rope_quant_qkv(
positions[:num_actual_toks, ...],
query_nope,
q,
q_quant,
q_scale,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
if has_prefill: if has_prefill:
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
...@@ -1259,12 +1300,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1259,12 +1300,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_decode: if has_decode:
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
decode_q = q_quant[:num_decode_tokens]
decode_q_nope, decode_q_pe = decode_q.split( decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P) # Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1) decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L) # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) # todo: bmm support
decode_ql_nope = torch.bmm(q_scale, decode_q_nope, self.W_UK_T) if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA else torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L) # Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1) decode_ql_nope = decode_ql_nope.transpose(0, 1)
......
...@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, ...@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe, flash_mla_with_kvcache_q_nope_pe,
get_mla_metadata, get_mla_metadata,
flash_mla_with_kvcache_fp8, flash_mla_with_kvcache_fp8,
flash_mla_with_kvcache_fp8_with_cat,
get_mla_decoding_metadata_dense_fp8, get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported) is_flashmla_supported)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -181,31 +182,48 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -181,31 +182,48 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8: if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
if envs.VLLM_USE_OPT_CAT: if envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
if q_nope.shape[0] < 1024: o, _ = flash_mla_with_kvcache_fp8_with_cat(
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode q_nope=q_nope.unsqueeze(1),
q = concat_helper_decode(q_nope, q_pe, dim=2)\ q_pe=q_pe.unsqueeze(1),
.unsqueeze(1) k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q = q_scale,
descale_k = k_scale,
)
else:
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
else: else:
q = torch.cat([q_nope, q_pe], dim=-1)\ q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode) .unsqueeze(1) # Add seqlen dim of 1 (decode)
else: o, _ = flash_mla_with_kvcache_fp8(
q = torch.cat([q_nope, q_pe], dim=-1)\ q=q.to(torch.float8_e4m3fn),
.unsqueeze(1) # Add seqlen dim of 1 (decode) k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
o, _ = flash_mla_with_kvcache_fp8( block_table=attn_metadata.decode.block_table,
q=q.to(torch.float8_e4m3fn), cache_seqlens=attn_metadata.decode.seq_lens,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1 head_dim_v=self.kv_lora_rank,
block_table=attn_metadata.decode.block_table, tile_scheduler_metadata=attn_metadata.decode.
cache_seqlens=attn_metadata.decode.seq_lens, tile_scheduler_metadata,
head_dim_v=self.kv_lora_rank, num_splits=attn_metadata.decode.num_splits,
tile_scheduler_metadata=attn_metadata.decode. softmax_scale=self.scale,
tile_scheduler_metadata, causal=True,
num_splits=attn_metadata.decode.num_splits, descale_q=q_scale,
softmax_scale=self.scale, descale_k=k_scale,
causal=True, )
descale_q=q_scale,
descale_k=k_scale,
)
else: else:
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3": if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment