Commit d7bee8b6 authored by wanghl6's avatar wanghl6
Browse files

feat: 元宝 prefill融合算子优化

parent d761561a
...@@ -2184,13 +2184,26 @@ def gather_cache(src_cache: torch.Tensor, ...@@ -2184,13 +2184,26 @@ def gather_cache(src_cache: torch.Tensor,
) -> None: ) -> None:
#支持"kv cache fp8" 临时方案,带dtype的gather_cache在vllm0.10后会实现。 #支持"kv cache fp8" 临时方案,带dtype的gather_cache在vllm0.10后会实现。
if kv_dtype == "fp8" or kv_dtype == "fp8_e5m2" or kv_dtype == "fp8_e4m3": if kv_dtype == "fp8" or kv_dtype == "fp8_e5m2" or kv_dtype == "fp8_e4m3":
dst_fp8 = torch.empty(dst.shape, dtype=torch.uint8, device=dst.device) if not envs.VLLM_FUSED_GATHER_CACHE_CONVERT_FP8:
#convert_fp8(dst_fp8, dst, scale, kv_dtype) dst_fp8 = torch.empty(dst.shape, dtype=torch.uint8, device=dst.device)
torch.ops._C_cache_ops.gather_cache(src_cache, dst_fp8, block_table, #convert_fp8(dst_fp8, dst, scale, kv_dtype)
cu_seq_lens, batch_size, seq_starts) torch.ops._C_cache_ops.gather_cache(src_cache, dst_fp8, block_table,
#dst_fp8->bf16 cu_seq_lens, batch_size, seq_starts)
# convert_fp8(dst, dst_fp8, scale, kv_dtype) #dst_fp8->bf16
convert_fp8(dst, dst_fp8, 1.0, kv_dtype) # convert_fp8(dst, dst_fp8, scale, kv_dtype)
convert_fp8(dst, dst_fp8, 1.0, kv_dtype)
else:
from lightop import op
op.gather_convert_fp8_cache(
src_cache,
dst,
block_table,
cu_seq_lens,
batch_size,
scale,
kv_dtype,
seq_starts
)
else: else:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts) cu_seq_lens, batch_size, seq_starts)
......
...@@ -219,7 +219,9 @@ if TYPE_CHECKING: ...@@ -219,7 +219,9 @@ if TYPE_CHECKING:
VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False
VLLM_USE_MOE_W16A16_TRITON: bool = False VLLM_USE_MOE_W16A16_TRITON: bool = False
VLLM_USE_FUSED_DTBMM: bool = False VLLM_USE_FUSED_DTBMM: bool = False
VLLM_FUSE_CAT_AND_CAST_FP8: bool = False
VLLM_FUSED_GATHER_CACHE_CONVERT_FP8: bool = False
VLLM_FUSED_RN_ROPE_INT8_QUANT: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
"XDG_CACHE_HOME", "XDG_CACHE_HOME",
...@@ -1404,6 +1406,15 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1404,6 +1406,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_DTBMM": "VLLM_USE_FUSED_DTBMM":
lambda: (os.environ.get("VLLM_USE_FUSED_DTBMM", "False").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_DTBMM", "False").lower() in
("true", "1")), ("true", "1")),
"VLLM_FUSE_CAT_AND_CAST_FP8":
lambda: (os.environ.get("VLLM_FUSE_CAT_AND_CAST_FP8", "False").lower() in
("true", "1")),
"VLLM_FUSED_GATHER_CACHE_CONVERT_FP8":
lambda: (os.environ.get("VLLM_FUSED_GATHER_CACHE_CONVERT_FP8", "False").lower() in
("true", "1")),
"VLLM_FUSED_RN_ROPE_INT8_QUANT":
lambda: (os.environ.get("VLLM_FUSED_RN_ROPE_INT8_QUANT", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -469,7 +469,7 @@ def apply_int8_linear( ...@@ -469,7 +469,7 @@ def apply_int8_linear(
# * dynamic, layer.input_scale is None and x_scale computed from x. # * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale. # * static, layer.input_scale is scalar and x_scale is input_scale.
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None: if (envs.USE_FUSED_RMS_QUANT or envs.VLLM_FUSED_RN_ROPE_INT8_QUANT) and input_quant_args is not None:
assert len(input_quant_args) == 2 assert len(input_quant_args) == 2
x_zp =None x_zp =None
x_q, x_scale = input_quant_args x_q, x_scale = input_quant_args
......
...@@ -1015,6 +1015,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1015,6 +1015,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
iters = len(prefill_metadata.chunked_context.seq_tot) iters = len(prefill_metadata.chunked_context.seq_tot)
workspace = prefill_metadata.chunked_context.workspace workspace = prefill_metadata.chunked_context.workspace
use_flash_fp8_arch = (
torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938"
and envs.VLLM_USE_FLASH_ATTN_FP8
)
use_fused_fp8_op = use_flash_fp8_arch and envs.VLLM_FUSE_CAT_AND_CAST_FP8
for i in range(iters): for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i] toks = prefill_metadata.chunked_context.seq_tot[i]
...@@ -1029,62 +1035,64 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1029,62 +1035,64 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale=kv_scale, scale=kv_scale,
) )
kv_c_normed = workspace[:toks]\ kv_c_normed = workspace[:toks][..., :self.kv_lora_rank]
[..., :self.kv_lora_rank] k_pe = workspace[:toks][..., self.kv_lora_rank:].unsqueeze(1)
k_pe = workspace[:toks]\
[..., self.kv_lora_rank:].unsqueeze(1)
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_OPT_CAT: if use_fused_fp8_op:
if k_nope.shape[0] > 1024: from lightop import op
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper k_pe_expanded = k_pe.expand(k_pe.shape[0], self.num_heads, k_pe.shape[-1])
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), q_attn, k_attn, v_attn = op.ds_fused_qkv_cast_fp8(
dim=2) q,
else: kv_nope,
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), k_pe_expanded,
dim=-1) self.qk_nope_head_dim,
self.v_head_dim
)
else: else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), k_nope, v_nope = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
dim=-1)
if envs.VLLM_USE_OPT_CAT:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_ATTN_FP8: if k_nope.shape[0] > 1024:
q_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) k_cat = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), dim=2)
v_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) else:
descale_shape = (attn_metadata.prefill.query_start_loc.numel() - 1, q.shape[1]) k_cat = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
q_descale = q_descale.expand(descale_shape) else:
k_descale = k_descale.expand(descale_shape) k_cat = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
v_descale = v_descale.expand(descale_shape)
q = q.to(torch.float8_e4m3fn) if use_flash_fp8_arch:
k = k.to(torch.float8_e4m3fn) q_attn = q.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn) k_attn = k_cat.to(torch.float8_e4m3fn)
v_attn = v_nope.to(torch.float8_e4m3fn)
attn_output, attn_softmax_lse = \ else:
self._flash_attn_varlen_diff_headdims( q_attn = q
q=q, k_attn = k_cat
k=k, v_attn = v_nope
v=v,
if use_flash_fp8_arch:
attn_output, attn_softmax_lse = self._flash_attn_varlen_diff_headdims(
q=q_attn,
k=k_attn,
v=v_attn,
cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len, max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i], max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
softmax_scale=self.scale, softmax_scale=self.scale,
causal=False, # Context is unmasked causal=False, # Context is unmasked
q_descale=q_descale, q_descale=None,
k_descale=k_descale, k_descale=None,
v_descale=v_descale, v_descale=None,
return_softmax_lse=True, return_softmax_lse=True,
) )
else: else:
attn_output, attn_softmax_lse = \ attn_output, attn_softmax_lse = self._flash_attn_varlen_diff_headdims(
self._flash_attn_varlen_diff_headdims( q=q_attn,
q=q, k=k_attn,
k=k, v=v_attn,
v=v,
cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len, max_seqlen_q=prefill_metadata.max_query_len,
...@@ -1124,6 +1132,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1124,6 +1132,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32), kv_scale=torch.tensor(1.0, dtype=torch.float32),
kv_quant_args: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
...@@ -1131,35 +1140,55 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1131,35 +1140,55 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
has_context = attn_metadata.prefill.chunked_context is not None has_context = attn_metadata.prefill.chunked_context is not None
else: else:
has_context = False has_context = False
if kv_quant_args is not None:
kv_nope = self.kv_b_proj.quant_method.apply(
self.kv_b_proj, kv_c_normed, input_quant_args=kv_quant_args)
else:
kv_nope = self.kv_b_proj(kv_c_normed)
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ if isinstance(kv_nope, tuple):
kv_nope = kv_nope[0]
kv_nope = kv_nope.view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) # kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
# -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024: use_flash_fp8_arch = (
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938"
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), and envs.VLLM_USE_FLASH_ATTN_FP8
dim=2) )
else: use_fused_fp8_op = use_flash_fp8_arch and envs.VLLM_FUSE_CAT_AND_CAST_FP8
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1) if use_fused_fp8_op:
from lightop import op
k_pe_expanded = k_pe.expand(k_pe.shape[0], self.num_heads, k_pe.shape[-1])
q, k, v = op.ds_fused_qkv_cast_fp8(
q,
kv_nope,
k_pe_expanded,
self.qk_nope_head_dim,
self.v_head_dim
)
else: else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_OPT_CAT:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_ATTN_FP8: if k_nope.shape[0] > 1024:
q_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), dim=2)
v_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) else:
descale_shape = (attn_metadata.prefill.query_start_loc.numel() - 1, q.shape[1]) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
q_descale = q_descale.expand(descale_shape) else:
k_descale = k_descale.expand(descale_shape) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
v_descale = v_descale.expand(descale_shape)
if use_flash_fp8_arch:
q = q.to(torch.float8_e4m3fn) q = q.to(torch.float8_e4m3fn)
k = k.to(torch.float8_e4m3fn) k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn) v = v.to(torch.float8_e4m3fn)
if use_flash_fp8_arch:
output = self._flash_attn_varlen_diff_headdims( output = self._flash_attn_varlen_diff_headdims(
q=q, q=q,
k=k, k=k,
...@@ -1170,9 +1199,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1170,9 +1199,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
max_seqlen_k=attn_metadata.prefill.max_query_len, max_seqlen_k=attn_metadata.prefill.max_query_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
q_descale=q_descale, q_descale=None,
k_descale=k_descale, k_descale=None,
v_descale=v_descale, v_descale=None,
return_softmax_lse=has_context, return_softmax_lse=has_context,
) )
else: else:
...@@ -1270,7 +1299,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1270,7 +1299,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
has_decode = attn_metadata.num_decodes > 0 has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
prefill_k_pe = k_pe[num_decode_tokens:] prefill_k_pe = k_pe[num_decode_tokens:]
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
decode_q = q[:num_decode_tokens] decode_q = q[:num_decode_tokens]
...@@ -1289,6 +1317,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1289,6 +1317,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
else: else:
kv_cache_dtype_str = self.kv_cache_dtype kv_cache_dtype_str = self.kv_cache_dtype
k_c_normed_int8 = None
k_c_normed_scale = None
# write the latent and rope to kv cache # write the latent and rope to kv cache
if kv_cache.numel() > 0: if kv_cache.numel() > 0:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
...@@ -1301,11 +1331,56 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1301,11 +1331,56 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale=layer._k_scale, scale=layer._k_scale,
) )
else: else:
if envs.VLLM_FUSED_RN_ROPE_INT8_QUANT:
k_c_normed_int8 = torch.empty((num_actual_toks, k_c_normed.size(-1)), dtype=torch.int8, device=q.device)
k_c_normed_scale = torch.empty((num_actual_toks, 1), dtype=torch.float32, device=q.device)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA: if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
if has_prefill: if has_prefill:
fused_rms_norm_rope_contiguous( if envs.VLLM_FUSED_RN_ROPE_INT8_QUANT:
from lightop import op
op.fused_rms_norm_rope_int8quant_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
k_c_normed_int8,
k_c_normed_scale,
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
if has_decode:
q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device)
q_quant = torch.empty_like(q_tensor, dtype=torch.float8_e4m3fn, device=q.device)
q_scale = torch.empty(q.shape[0], dtype=torch.float32, device=q.device)
fuse_rmsnorm_rope_quant_qkv(
positions[:num_actual_toks, ...], positions[:num_actual_toks, ...],
query_nope,
q, q,
q_quant,
q_scale,
k_pe.squeeze(1), k_pe.squeeze(1),
k_c_normed, # not normed k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed key_normed[:num_actual_toks, ...], # normed
...@@ -1318,16 +1393,30 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1318,16 +1393,30 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
False, False,
1e-6, 1e-6,
) )
if has_decode: else:
q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device) if envs.VLLM_FUSED_RN_ROPE_INT8_QUANT:
q_quant = torch.empty_like(q_tensor, dtype=torch.float8_e4m3fn, device=q.device) from lightop import op
q_scale = torch.empty(q.shape[0], dtype=torch.float32, device=q.device) op.fused_rms_norm_rope_int8quant_contiguous(
fuse_rmsnorm_rope_quant_qkv( positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
k_c_normed_int8,
k_c_normed_scale,
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...], positions[:num_actual_toks, ...],
query_nope,
q, q,
q_quant,
q_scale,
k_pe.squeeze(1), k_pe.squeeze(1),
k_c_normed, # not normed k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed key_normed[:num_actual_toks, ...], # normed
...@@ -1340,31 +1429,27 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1340,31 +1429,27 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
False, False,
1e-6, 1e-6,
) )
else:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
if has_prefill: if has_prefill:
curr_kv_quant = None
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
prefill_k_c_normed = key_normed[:num_actual_toks, ...] prefill_k_c_normed = key_normed[:num_actual_toks, ...]
prefill_k_c_normed = prefill_k_c_normed[num_decode_tokens:] prefill_k_c_normed = prefill_k_c_normed[num_decode_tokens:]
output[num_decode_tokens:] = self._forward_prefill( else:
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, prefill_k_c_normed = k_c_normed[num_decode_tokens:]
attn_metadata, kv_scale=layer._k_scale)
if envs.VLLM_FUSED_RN_ROPE_INT8_QUANT and prefill_k_c_normed is not None:
if k_c_normed_int8 is not None and k_c_normed_scale is not None:
curr_kv_quant = [k_c_normed_int8[num_decode_tokens:], k_c_normed_scale[num_decode_tokens:]]
output[num_decode_tokens:] = self._forward_prefill(
prefill_q,
prefill_k_c_normed,
prefill_k_pe,
kv_cache,
attn_metadata,
kv_scale=layer._k_scale,
kv_quant_args=curr_kv_quant
)
if has_decode: if has_decode:
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and self.kv_cache_dtype=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA: if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and self.kv_cache_dtype=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment