Commit 347fc09c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-nmz' into v0.9.2-dev

parents ffcc47b7 3e191138
......@@ -16,7 +16,10 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonState)
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
flash_mla_with_kvcache_fp8,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported)
from vllm import envs
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
......@@ -93,7 +96,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self.num_q_heads,
1, # MQA for the decode path
)
return m
......@@ -222,6 +224,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
q_scale = None,
k_scale = None,
kv_cache_dtype = "auto",
) -> torch.Tensor:
......@@ -233,6 +236,21 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.view(torch.float8_e4m3fn).unsqueeze(-2), # Add head dim of 1
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata,
num_splits=decode_meta.decode_num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=q_scale,
descale_k=k_scale,
)
else:
o, _ = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
......@@ -246,5 +264,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
return self._v_up_proj(o)
......@@ -1404,6 +1404,5 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[num_prefill_tokens:] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._k_scale, self.kv_cache_dtype)
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._q_scale, layer._k_scale, self.kv_cache_dtype)
return output
\ No newline at end of file
......@@ -69,6 +69,27 @@ def get_mla_metadata(
num_heads_k)
def get_mla_decoding_metadata_dense_fp8(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
Return:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(cache_seqlens,
num_heads_per_head_k,
num_heads_k)
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
......@@ -199,6 +220,59 @@ def flash_mla_with_kvcache_q_nope_pe(
return out, softmax_lse
def flash_mla_with_kvcache_fp8(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head_dim of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
torch.int32, return by get_mla_decoding_metadata_dense_fp8.
num_splits: (batch_size + 1), torch.int32, return by get_mla_decoding_metadata_dense_fp8.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
return out, softmax_lse
#
# TODO: Add fake functions
#
......
......@@ -146,6 +146,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_FLASH_MLA_FP8: bool = False
VLLM_USE_OPT_OP: bool = False
VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
......@@ -1038,6 +1039,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))),
# If set, vLLM will use FLASH MLA fp8 attention optimizations.
"VLLM_USE_FLASH_MLA_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA_FP8", "0"))),
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP":
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
......
......@@ -255,8 +255,8 @@ def get_model_architecture(
os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
# os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
......@@ -300,8 +300,8 @@ def get_model_architecture(
os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
# os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
......
This diff is collapsed.
......@@ -1095,6 +1095,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_metadata: M,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
......@@ -1154,7 +1156,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale=layer._k_scale,
)
else:
from lightop import fused_rms_norm_rope_contiguous
if self.kv_cache_dtype == "auto":
if q.dtype == torch.float16:
kv_cache_dtype_str = "fp16"
......@@ -1162,7 +1163,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str = "bf16"
else:
kv_cache_dtype_str = self.kv_cache_dtype
from lightop import fused_rms_norm_rope_contiguous
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
......@@ -1199,6 +1200,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._k_scale, self.kv_cache_dtype)
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._q_scale, layer._k_scale, self.kv_cache_dtype)
return output_padded
\ No newline at end of file
......@@ -11,6 +11,8 @@ from vllm.attention.backends.abstract import (AttentionType,
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe,
get_mla_metadata,
flash_mla_with_kvcache_fp8,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
......@@ -162,12 +164,42 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
q_scale = None,
k_scale = None,
kv_cache_dtype = "auto",
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.view(torch.float8_e4m3fn).unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=q_scale,
descale_k=k_scale,
)
else:
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment