Commit 347fc09c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-nmz' into v0.9.2-dev

parents ffcc47b7 3e191138
...@@ -16,7 +16,10 @@ from vllm.attention.backends.mla.common import (MLACommonBackend, ...@@ -16,7 +16,10 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonState) MLACommonState)
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata, get_mla_metadata,
flash_mla_with_kvcache_fp8,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported) is_flashmla_supported)
from vllm import envs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
...@@ -93,7 +96,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -93,7 +96,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self.num_q_heads, self.num_q_heads,
1, # MQA for the decode path 1, # MQA for the decode path
) )
return m return m
...@@ -222,6 +224,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -222,6 +224,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe: torch.Tensor, q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata, attn_metadata: FlashMLAMetadata,
q_scale = None,
k_scale = None, k_scale = None,
kv_cache_dtype = "auto", kv_cache_dtype = "auto",
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -233,18 +236,32 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -233,18 +236,32 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q = torch.cat([q_nope, q_pe], dim=-1)\ q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode) .unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache( if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
q=q, o, _ = flash_mla_with_kvcache_fp8(
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 q=q.to(torch.float8_e4m3fn),
block_table=decode_meta.block_tables, k_cache=kv_c_and_k_pe_cache.view(torch.float8_e4m3fn).unsqueeze(-2), # Add head dim of 1
cache_seqlens=decode_meta.seq_lens_tensor, block_table=decode_meta.block_tables,
head_dim_v=self.kv_lora_rank, cache_seqlens=decode_meta.seq_lens_tensor,
tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata, head_dim_v=self.kv_lora_rank,
num_splits=decode_meta.decode_num_splits, tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata,
softmax_scale=self.scale, num_splits=decode_meta.decode_num_splits,
causal=True, softmax_scale=self.scale,
k_scale = k_scale, causal=True,
kv_cache_dtype = kv_cache_dtype, descale_q=q_scale,
) descale_k=k_scale,
)
else:
o, _ = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata,
num_splits=decode_meta.decode_num_splits,
softmax_scale=self.scale,
causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
return self._v_up_proj(o) return self._v_up_proj(o)
...@@ -1404,6 +1404,5 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1404,6 +1404,5 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
decode_ql_nope = decode_ql_nope.transpose(0, 1) decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[num_prefill_tokens:] = self._forward_decode( output[num_prefill_tokens:] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._k_scale, self.kv_cache_dtype) decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._q_scale, layer._k_scale, self.kv_cache_dtype)
return output return output
\ No newline at end of file
...@@ -69,6 +69,27 @@ def get_mla_metadata( ...@@ -69,6 +69,27 @@ def get_mla_metadata(
num_heads_k) num_heads_k)
def get_mla_decoding_metadata_dense_fp8(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
Return:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(cache_seqlens,
num_heads_per_head_k,
num_heads_k)
def flash_mla_with_kvcache( def flash_mla_with_kvcache(
q: torch.Tensor, q: torch.Tensor,
k_cache: torch.Tensor, k_cache: torch.Tensor,
...@@ -199,6 +220,59 @@ def flash_mla_with_kvcache_q_nope_pe( ...@@ -199,6 +220,59 @@ def flash_mla_with_kvcache_q_nope_pe(
return out, softmax_lse return out, softmax_lse
def flash_mla_with_kvcache_fp8(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head_dim of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
torch.int32, return by get_mla_decoding_metadata_dense_fp8.
num_splits: (batch_size + 1), torch.int32, return by get_mla_decoding_metadata_dense_fp8.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
return out, softmax_lse
# #
# TODO: Add fake functions # TODO: Add fake functions
# #
......
...@@ -146,6 +146,7 @@ if TYPE_CHECKING: ...@@ -146,6 +146,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_TRITON_OPT_MLA: bool = False VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_FLASH_MLA: bool = False VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_FLASH_MLA_FP8: bool = False
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
VLLM_USE_TC_PAGED_ATTN: bool = False VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False VLLM_USE_PA_PRINT_PARAM: bool = False
...@@ -1038,6 +1039,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1038,6 +1039,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_MLA": "VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))), lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))),
# If set, vLLM will use FLASH MLA fp8 attention optimizations.
"VLLM_USE_FLASH_MLA_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA_FP8", "0"))),
# flag to control vllm to use optimized kernels # flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP": "VLLM_USE_OPT_OP":
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
......
...@@ -255,8 +255,8 @@ def get_model_architecture( ...@@ -255,8 +255,8 @@ def get_model_architecture(
os.environ['VLLM_USE_CAT_MLA'] = '1' os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"): if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1' os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"): if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
# os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1' os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"): if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1' os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}: if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
...@@ -300,8 +300,8 @@ def get_model_architecture( ...@@ -300,8 +300,8 @@ def get_model_architecture(
os.environ['VLLM_USE_CAT_MLA'] = '1' os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"): if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1' os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"): if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
# os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1' os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"): if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1' os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}: if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
......
This diff is collapsed.
...@@ -1095,6 +1095,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1095,6 +1095,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_metadata: M, attn_metadata: M,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None, q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None, key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None,
...@@ -1154,7 +1156,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1154,7 +1156,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale=layer._k_scale, scale=layer._k_scale,
) )
else: else:
from lightop import fused_rms_norm_rope_contiguous
if self.kv_cache_dtype == "auto": if self.kv_cache_dtype == "auto":
if q.dtype == torch.float16: if q.dtype == torch.float16:
kv_cache_dtype_str = "fp16" kv_cache_dtype_str = "fp16"
...@@ -1162,7 +1163,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1162,7 +1163,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str = "bf16" kv_cache_dtype_str = "bf16"
else: else:
kv_cache_dtype_str = self.kv_cache_dtype kv_cache_dtype_str = self.kv_cache_dtype
from lightop import fused_rms_norm_rope_contiguous
fused_rms_norm_rope_contiguous( fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...], positions[:num_actual_toks, ...],
q, q,
...@@ -1199,6 +1200,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1199,6 +1200,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_ql_nope = decode_ql_nope.transpose(0, 1) decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[:num_decode_tokens] = self._forward_decode( output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._k_scale, self.kv_cache_dtype) decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._q_scale, layer._k_scale, self.kv_cache_dtype)
return output_padded return output_padded
\ No newline at end of file
...@@ -11,6 +11,8 @@ from vllm.attention.backends.abstract import (AttentionType, ...@@ -11,6 +11,8 @@ from vllm.attention.backends.abstract import (AttentionType,
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe, flash_mla_with_kvcache_q_nope_pe,
get_mla_metadata, get_mla_metadata,
flash_mla_with_kvcache_fp8,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported) is_flashmla_supported)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend, from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
...@@ -162,13 +164,14 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -162,13 +164,14 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe: torch.Tensor, q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata, attn_metadata: FlashMLAMetadata,
q_scale = None,
k_scale = None, k_scale = None,
kv_cache_dtype = "auto", kv_cache_dtype = "auto",
) -> torch.Tensor: ) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3": if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
if envs.VLLM_USE_OPT_CAT: if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024: if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
...@@ -180,11 +183,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -180,11 +183,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
else: else:
q = torch.cat([q_nope, q_pe], dim=-1)\ q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode) .unsqueeze(1) # Add seqlen dim of 1 (decode)
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3": o, _ = flash_mla_with_kvcache_fp8(
o, _ = flash_mla_with_kvcache( q=q.to(torch.float8_e4m3fn),
q=q, k_cache=kv_c_and_k_pe_cache.view(torch.float8_e4m3fn).unsqueeze(-2), # Add head dim of 1
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table, block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens, cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank, head_dim_v=self.kv_lora_rank,
...@@ -193,24 +195,54 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -193,24 +195,54 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=attn_metadata.decode.num_splits, num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
k_scale = k_scale, descale_q=q_scale,
kv_cache_dtype = kv_cache_dtype, descale_k=k_scale,
) )
else: else:
o, _ = flash_mla_with_kvcache_q_nope_pe( if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
q_nope=q_nope.unsqueeze(1), if envs.VLLM_USE_OPT_CAT:
q_pe=q_pe.unsqueeze(1), if q_nope.shape[0] < 1024:
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
block_table=attn_metadata.decode.block_table, q = concat_helper_decode(q_nope, q_pe, dim=2)\
cache_seqlens=attn_metadata.decode.seq_lens, .unsqueeze(1)
head_dim_v=self.kv_lora_rank, else:
tile_scheduler_metadata=attn_metadata.decode. q = torch.cat([q_nope, q_pe], dim=-1)\
tile_scheduler_metadata, .unsqueeze(1) # Add seqlen dim of 1 (decode)
num_splits=attn_metadata.decode.num_splits, else:
softmax_scale=self.scale, q = torch.cat([q_nope, q_pe], dim=-1)\
causal=True, .unsqueeze(1) # Add seqlen dim of 1 (decode)
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype, if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
) o, _ = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
else:
o, _ = flash_mla_with_kvcache_q_nope_pe(
q_nope=q_nope.unsqueeze(1),
q_pe=q_pe.unsqueeze(1),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
return self._v_up_proj(o) return self._v_up_proj(o)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment