Commit d7bee8b6 authored by wanghl6
Browse files

feat: 元宝 prefill融合算子优化

parent d761561a
......@@ -2184,6 +2184,7 @@ def gather_cache(src_cache: torch.Tensor,
) -> None:
# Temporary workaround to support an fp8 KV cache; a dtype-aware gather_cache will be implemented after vLLM 0.10.
if kv_dtype == "fp8" or kv_dtype == "fp8_e5m2" or kv_dtype == "fp8_e4m3":
if not envs.VLLM_FUSED_GATHER_CACHE_CONVERT_FP8:
dst_fp8 = torch.empty(dst.shape, dtype=torch.uint8, device=dst.device)
#convert_fp8(dst_fp8, dst, scale, kv_dtype)
torch.ops._C_cache_ops.gather_cache(src_cache, dst_fp8, block_table,
......@@ -2191,6 +2192,18 @@ def gather_cache(src_cache: torch.Tensor,
#dst_fp8->bf16
# convert_fp8(dst, dst_fp8, scale, kv_dtype)
convert_fp8(dst, dst_fp8, 1.0, kv_dtype)
else:
from lightop import op
op.gather_convert_fp8_cache(
src_cache,
dst,
block_table,
cu_seq_lens,
batch_size,
scale,
kv_dtype,
seq_starts
)
else:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)
......
......@@ -219,7 +219,9 @@ if TYPE_CHECKING:
VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False
VLLM_USE_MOE_W16A16_TRITON: bool = False
VLLM_USE_FUSED_DTBMM: bool = False
VLLM_FUSE_CAT_AND_CAST_FP8: bool = False
VLLM_FUSED_GATHER_CACHE_CONVERT_FP8: bool = False
VLLM_FUSED_RN_ROPE_INT8_QUANT: bool = False
def get_default_cache_root():
return os.getenv(
"XDG_CACHE_HOME",
......@@ -1404,6 +1406,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_DTBMM":
lambda: (os.environ.get("VLLM_USE_FUSED_DTBMM", "False").lower() in
("true", "1")),
"VLLM_FUSE_CAT_AND_CAST_FP8":
lambda: (os.environ.get("VLLM_FUSE_CAT_AND_CAST_FP8", "False").lower() in
("true", "1")),
"VLLM_FUSED_GATHER_CACHE_CONVERT_FP8":
lambda: (os.environ.get("VLLM_FUSED_GATHER_CACHE_CONVERT_FP8", "False").lower() in
("true", "1")),
"VLLM_FUSED_RN_ROPE_INT8_QUANT":
lambda: (os.environ.get("VLLM_FUSED_RN_ROPE_INT8_QUANT", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -469,7 +469,7 @@ def apply_int8_linear(
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
if (envs.USE_FUSED_RMS_QUANT or envs.VLLM_FUSED_RN_ROPE_INT8_QUANT) and input_quant_args is not None:
assert len(input_quant_args) == 2
x_zp =None
x_q, x_scale = input_quant_args
......
......@@ -1015,6 +1015,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
iters = len(prefill_metadata.chunked_context.seq_tot)
workspace = prefill_metadata.chunked_context.workspace
use_flash_fp8_arch = (
torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938"
and envs.VLLM_USE_FLASH_ATTN_FP8
)
use_fused_fp8_op = use_flash_fp8_arch and envs.VLLM_FUSE_CAT_AND_CAST_FP8
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
......@@ -1029,62 +1035,64 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale=kv_scale,
)
kv_c_normed = workspace[:toks]\
[..., :self.kv_lora_rank]
k_pe = workspace[:toks]\
[..., self.kv_lora_rank:].unsqueeze(1)
kv_c_normed = workspace[:toks][..., :self.kv_lora_rank]
k_pe = workspace[:toks][..., self.kv_lora_rank:].unsqueeze(1)
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if use_fused_fp8_op:
from lightop import op
k_pe_expanded = k_pe.expand(k_pe.shape[0], self.num_heads, k_pe.shape[-1])
q_attn, k_attn, v_attn = op.ds_fused_qkv_cast_fp8(
q,
kv_nope,
k_pe_expanded,
self.qk_nope_head_dim,
self.v_head_dim
)
else:
k_nope, v_nope = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2)
k_cat = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), dim=2)
else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
k_cat = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_ATTN_FP8:
q_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
k_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
v_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
descale_shape = (attn_metadata.prefill.query_start_loc.numel() - 1, q.shape[1])
q_descale = q_descale.expand(descale_shape)
k_descale = k_descale.expand(descale_shape)
v_descale = v_descale.expand(descale_shape)
q = q.to(torch.float8_e4m3fn)
k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn)
k_cat = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
if use_flash_fp8_arch:
q_attn = q.to(torch.float8_e4m3fn)
k_attn = k_cat.to(torch.float8_e4m3fn)
v_attn = v_nope.to(torch.float8_e4m3fn)
else:
q_attn = q
k_attn = k_cat
v_attn = v_nope
if use_flash_fp8_arch:
attn_output, attn_softmax_lse = self._flash_attn_varlen_diff_headdims(
q=q_attn,
k=k_attn,
v=v_attn,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
q_descale=None,
k_descale=None,
v_descale=None,
return_softmax_lse=True,
)
else:
attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
attn_output, attn_softmax_lse = self._flash_attn_varlen_diff_headdims(
q=q_attn,
k=k_attn,
v=v_attn,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
......@@ -1124,6 +1132,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32),
kv_quant_args: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor:
assert attn_metadata.prefill is not None
......@@ -1132,34 +1141,54 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
else:
has_context = False
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
if kv_quant_args is not None:
kv_nope = self.kv_b_proj.quant_method.apply(
self.kv_b_proj, kv_c_normed, input_quant_args=kv_quant_args)
else:
kv_nope = self.kv_b_proj(kv_c_normed)
if isinstance(kv_nope, tuple):
kv_nope = kv_nope[0]
kv_nope = kv_nope.view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
# -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
use_flash_fp8_arch = (
torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938"
and envs.VLLM_USE_FLASH_ATTN_FP8
)
use_fused_fp8_op = use_flash_fp8_arch and envs.VLLM_FUSE_CAT_AND_CAST_FP8
if use_fused_fp8_op:
from lightop import op
k_pe_expanded = k_pe.expand(k_pe.shape[0], self.num_heads, k_pe.shape[-1])
q, k, v = op.ds_fused_qkv_cast_fp8(
q,
kv_nope,
k_pe_expanded,
self.qk_nope_head_dim,
self.v_head_dim
)
else:
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2)
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), dim=2)
else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_ATTN_FP8:
q_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
k_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
v_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
descale_shape = (attn_metadata.prefill.query_start_loc.numel() - 1, q.shape[1])
q_descale = q_descale.expand(descale_shape)
k_descale = k_descale.expand(descale_shape)
v_descale = v_descale.expand(descale_shape)
if use_flash_fp8_arch:
q = q.to(torch.float8_e4m3fn)
k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn)
if use_flash_fp8_arch:
output = self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
......@@ -1170,9 +1199,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
max_seqlen_k=attn_metadata.prefill.max_query_len,
softmax_scale=self.scale,
causal=True,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
q_descale=None,
k_descale=None,
v_descale=None,
return_softmax_lse=has_context,
)
else:
......@@ -1270,7 +1299,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
prefill_k_pe = k_pe[num_decode_tokens:]
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
decode_q = q[:num_decode_tokens]
......@@ -1289,6 +1317,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
else:
kv_cache_dtype_str = self.kv_cache_dtype
k_c_normed_int8 = None
k_c_normed_scale = None
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
......@@ -1301,8 +1331,31 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale=layer._k_scale,
)
else:
if envs.VLLM_FUSED_RN_ROPE_INT8_QUANT:
k_c_normed_int8 = torch.empty((num_actual_toks, k_c_normed.size(-1)), dtype=torch.int8, device=q.device)
k_c_normed_scale = torch.empty((num_actual_toks, 1), dtype=torch.float32, device=q.device)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
if has_prefill:
if envs.VLLM_FUSED_RN_ROPE_INT8_QUANT:
from lightop import op
op.fused_rms_norm_rope_int8quant_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
k_c_normed_int8,
k_c_normed_scale,
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
......@@ -1340,6 +1393,26 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
False,
1e-6,
)
else:
if envs.VLLM_FUSED_RN_ROPE_INT8_QUANT:
from lightop import op
op.fused_rms_norm_rope_int8quant_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
k_c_normed_int8,
k_c_normed_scale,
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
......@@ -1356,14 +1429,26 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
False,
1e-6,
)
if has_prefill:
curr_kv_quant = None
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
prefill_k_c_normed = key_normed[:num_actual_toks, ...]
prefill_k_c_normed = prefill_k_c_normed[num_decode_tokens:]
else:
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
if envs.VLLM_FUSED_RN_ROPE_INT8_QUANT and prefill_k_c_normed is not None:
if k_c_normed_int8 is not None and k_c_normed_scale is not None:
curr_kv_quant = [k_c_normed_int8[num_decode_tokens:], k_c_normed_scale[num_decode_tokens:]]
output[num_decode_tokens:] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata, kv_scale=layer._k_scale)
prefill_q,
prefill_k_c_normed,
prefill_k_pe,
kv_cache,
attn_metadata,
kv_scale=layer._k_scale,
kv_quant_args=curr_kv_quant
)
if has_decode:
assert attn_metadata.decode is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment