Commit 5b7cf3ae authored by wangmin6's avatar wangmin6
Browse files

Merge branch '0.9.2-dev-tx-kernel_fuse' into 'v0.9.2-dev'

feat: 元宝 prefill融合算子优化

See merge request dcutoolkit/deeplearing/vllm!493
parents a34bff19 45517312
......@@ -2184,13 +2184,26 @@ def gather_cache(src_cache: torch.Tensor,
) -> None:
#支持"kv cache fp8" 临时方案,带dtype的gather_cache在vllm0.10后会实现。
if kv_dtype == "fp8" or kv_dtype == "fp8_e5m2" or kv_dtype == "fp8_e4m3":
dst_fp8 = torch.empty(dst.shape, dtype=torch.uint8, device=dst.device)
#convert_fp8(dst_fp8, dst, scale, kv_dtype)
torch.ops._C_cache_ops.gather_cache(src_cache, dst_fp8, block_table,
cu_seq_lens, batch_size, seq_starts)
#dst_fp8->bf16
# convert_fp8(dst, dst_fp8, scale, kv_dtype)
convert_fp8(dst, dst_fp8, 1.0, kv_dtype)
if not envs.VLLM_FUSED_GATHER_CACHE_CONVERT_FP8:
dst_fp8 = torch.empty(dst.shape, dtype=torch.uint8, device=dst.device)
#convert_fp8(dst_fp8, dst, scale, kv_dtype)
torch.ops._C_cache_ops.gather_cache(src_cache, dst_fp8, block_table,
cu_seq_lens, batch_size, seq_starts)
#dst_fp8->bf16
# convert_fp8(dst, dst_fp8, scale, kv_dtype)
convert_fp8(dst, dst_fp8, 1.0, kv_dtype)
else:
from lightop import op
op.gather_convert_fp8_cache(
src_cache,
dst,
block_table,
cu_seq_lens,
batch_size,
scale,
kv_dtype,
seq_starts
)
else:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)
......
......@@ -220,7 +220,8 @@ if TYPE_CHECKING:
VLLM_USE_MOE_W16A16_TRITON: bool = False
VLLM_USE_FUSED_DTBMM: bool = False
VLLM_FUSE_CAT_AND_CAST_FP8: bool = False
VLLM_FUSED_GATHER_CACHE_CONVERT_FP8: bool = False
VLLM_FUSED_RN_ROPE_INT8_QUANT: bool = False
def get_default_cache_root():
return os.getenv(
"XDG_CACHE_HOME",
......@@ -1408,6 +1409,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_FUSE_CAT_AND_CAST_FP8":
lambda: (os.environ.get("VLLM_FUSE_CAT_AND_CAST_FP8", "False").lower() in
("true", "1")),
"VLLM_FUSED_GATHER_CACHE_CONVERT_FP8":
lambda: (os.environ.get("VLLM_FUSED_GATHER_CACHE_CONVERT_FP8", "False").lower() in
("true", "1")),
"VLLM_FUSED_RN_ROPE_INT8_QUANT":
lambda: (os.environ.get("VLLM_FUSED_RN_ROPE_INT8_QUANT", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -469,7 +469,7 @@ def apply_int8_linear(
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
if (envs.USE_FUSED_RMS_QUANT or envs.VLLM_FUSED_RN_ROPE_INT8_QUANT) and input_quant_args is not None:
assert len(input_quant_args) == 2
x_zp =None
x_q, x_scale = input_quant_args
......
......@@ -1015,6 +1015,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
iters = len(prefill_metadata.chunked_context.seq_tot)
workspace = prefill_metadata.chunked_context.workspace
use_flash_fp8_arch = (
torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938"
and envs.VLLM_USE_FLASH_ATTN_FP8
)
use_fused_fp8_op = use_flash_fp8_arch and envs.VLLM_FUSE_CAT_AND_CAST_FP8
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
......@@ -1029,24 +1035,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale=kv_scale,
)
kv_c_normed = workspace[:toks]\
[..., :self.kv_lora_rank]
k_pe = workspace[:toks]\
[..., self.kv_lora_rank:].unsqueeze(1)
kv_c_normed = workspace[:toks][..., :self.kv_lora_rank]
k_pe = workspace[:toks][..., self.kv_lora_rank:].unsqueeze(1)
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
use_flash_fp8_arch = ( \
torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" \
and envs.VLLM_USE_FLASH_ATTN_FP8
)
use_fused_fp8_op = use_flash_fp8_arch and envs.VLLM_FUSE_CAT_AND_CAST_FP8
k_pe_expanded = k_pe.expand(k_pe.shape[0], self.num_heads, k_pe.shape[-1])
if use_fused_fp8_op:
from lightop import op
q, k, v = op.ds_fused_qkv_cast_fp8(
k_pe_expanded = k_pe.expand(k_pe.shape[0], self.num_heads, k_pe.shape[-1])
q_attn, k_attn, v_attn = op.ds_fused_qkv_cast_fp8(
q,
kv_nope,
k_pe_expanded,
......@@ -1054,48 +1052,47 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.v_head_dim
)
else:
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_nope, v_nope = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe_expanded, dim=2)
k_cat = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), dim=2)
else:
k = torch.cat((k_nope, k_pe_expanded), dim=-1)
k_cat = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
else:
k_cat = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
if use_flash_fp8_arch:
q_attn = q.to(torch.float8_e4m3fn)
k_attn = k_cat.to(torch.float8_e4m3fn)
v_attn = v_nope.to(torch.float8_e4m3fn)
else:
k = torch.cat((k_nope, k_pe_expanded), dim=-1)
q_attn = q
k_attn = k_cat
v_attn = v_nope
if use_flash_fp8_arch:
q_descale = None
k_descale = None
v_descale = None
if not use_fused_fp8_op:
q = q.to(torch.float8_e4m3fn)
k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn)
attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
attn_output, attn_softmax_lse = self._flash_attn_varlen_diff_headdims(
q=q_attn,
k=k_attn,
v=v_attn,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
q_descale=None,
k_descale=None,
v_descale=None,
return_softmax_lse=True,
)
else:
attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
attn_output, attn_softmax_lse = self._flash_attn_varlen_diff_headdims(
q=q_attn,
k=k_attn,
v=v_attn,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
......@@ -1135,6 +1132,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32),
kv_quant_args: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor:
assert attn_metadata.prefill is not None
......@@ -1142,15 +1140,28 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
has_context = attn_metadata.prefill.chunked_context is not None
else:
has_context = False
if kv_quant_args is not None:
kv_nope = self.kv_b_proj.quant_method.apply(
self.kv_b_proj, kv_c_normed, input_quant_args=kv_quant_args)
else:
kv_nope = self.kv_b_proj(kv_c_normed)
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
if isinstance(kv_nope, tuple):
kv_nope = kv_nope[0]
kv_nope = kv_nope.view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
use_flash_fp8_arch = ( \
torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" \
# kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
# -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
use_flash_fp8_arch = (
torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938"
and envs.VLLM_USE_FLASH_ATTN_FP8
)
use_fused_fp8_op = use_flash_fp8_arch and envs.VLLM_FUSE_CAT_AND_CAST_FP8
use_fused_fp8_op = use_flash_fp8_arch and envs.VLLM_FUSE_CAT_AND_CAST_FP8
if use_fused_fp8_op:
from lightop import op
k_pe_expanded = k_pe.expand(k_pe.shape[0], self.num_heads, k_pe.shape[-1])
......@@ -1172,14 +1183,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
if use_flash_fp8_arch:
q_descale = None
k_descale = None
v_descale = None
if not use_fused_fp8_op:
if use_flash_fp8_arch:
q = q.to(torch.float8_e4m3fn)
k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn)
if use_flash_fp8_arch:
output = self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
......@@ -1190,9 +1199,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
max_seqlen_k=attn_metadata.prefill.max_query_len,
softmax_scale=self.scale,
causal=True,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
q_descale=None,
k_descale=None,
v_descale=None,
return_softmax_lse=has_context,
)
else:
......@@ -1308,6 +1317,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
else:
kv_cache_dtype_str = self.kv_cache_dtype
k_c_normed_int8 = None
k_c_normed_scale = None
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
......@@ -1320,11 +1331,56 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale=layer._k_scale,
)
else:
if envs.VLLM_FUSED_RN_ROPE_INT8_QUANT:
k_c_normed_int8 = torch.empty((num_actual_toks, k_c_normed.size(-1)), dtype=torch.int8, device=q.device)
k_c_normed_scale = torch.empty((num_actual_toks, 1), dtype=torch.float32, device=q.device)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
if has_prefill:
fused_rms_norm_rope_contiguous(
if envs.VLLM_FUSED_RN_ROPE_INT8_QUANT:
from lightop import op
op.fused_rms_norm_rope_int8quant_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
k_c_normed_int8,
k_c_normed_scale,
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
if has_decode:
q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device)
q_quant = torch.empty_like(q_tensor, dtype=torch.float8_e4m3fn, device=q.device)
q_scale = torch.empty(q.shape[0], dtype=torch.float32, device=q.device)
fuse_rmsnorm_rope_quant_qkv(
positions[:num_actual_toks, ...],
query_nope,
q,
q_quant,
q_scale,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
......@@ -1337,16 +1393,30 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
False,
1e-6,
)
if has_decode:
q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device)
q_quant = torch.empty_like(q_tensor, dtype=torch.float8_e4m3fn, device=q.device)
q_scale = torch.empty(q.shape[0], dtype=torch.float32, device=q.device)
fuse_rmsnorm_rope_quant_qkv(
else:
if envs.VLLM_FUSED_RN_ROPE_INT8_QUANT:
from lightop import op
op.fused_rms_norm_rope_int8quant_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
k_c_normed_int8,
k_c_normed_scale,
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
query_nope,
q,
q_quant,
q_scale,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
......@@ -1359,30 +1429,27 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
False,
1e-6,
)
else:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
if has_prefill:
curr_kv_quant = None
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
prefill_k_c_normed = key_normed[:num_actual_toks, ...]
prefill_k_c_normed = prefill_k_c_normed[num_decode_tokens:]
output[num_decode_tokens:] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata, kv_scale=layer._k_scale)
else:
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
if envs.VLLM_FUSED_RN_ROPE_INT8_QUANT and prefill_k_c_normed is not None:
if k_c_normed_int8 is not None and k_c_normed_scale is not None:
curr_kv_quant = [k_c_normed_int8[num_decode_tokens:], k_c_normed_scale[num_decode_tokens:]]
output[num_decode_tokens:] = self._forward_prefill(
prefill_q,
prefill_k_c_normed,
prefill_k_pe,
kv_cache,
attn_metadata,
kv_scale=layer._k_scale,
kv_quant_args=curr_kv_quant
)
if has_decode:
assert attn_metadata.decode is not None
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and self.kv_cache_dtype=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment