Commit f441aca2 authored by zhuwenwen's avatar zhuwenwen
Browse files

update mqa_logits and paged_mqa_logits

parent cc7715fd
...@@ -83,6 +83,7 @@ from vllm import _custom_ops as ops ...@@ -83,6 +83,7 @@ from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
if current_platform.is_rocm(): if current_platform.is_rocm():
import lightop
from lightop import op, gemmopt from lightop import op, gemmopt
else: else:
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
...@@ -601,33 +602,35 @@ def sparse_attn_indexer( ...@@ -601,33 +602,35 @@ def sparse_attn_indexer(
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
ops.indexer_k_quant_and_cache( if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
k, ops.indexer_k_quant_and_cache(
kv_cache, k,
slot_mapping, kv_cache,
quant_block_size, slot_mapping,
scale_fmt, quant_block_size,
) scale_fmt,
)
topk_indices_buffer[:hidden_states.shape[0]] = -1 topk_indices_buffer[:hidden_states.shape[0]] = -1
if has_prefill: if has_prefill:
prefill_metadata = attn_metadata.prefill prefill_metadata = attn_metadata.prefill
for chunk in prefill_metadata.chunks: for chunk in prefill_metadata.chunks:
k_fp8 = torch.empty([chunk.total_seq_lens, head_dim], if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
device=k.device, device=k.device,
dtype=torch.float8_e4m3fn) dtype=torch.float8_e4m3fn)
k_scale = torch.empty([chunk.total_seq_lens, 1], k_scale = torch.empty([chunk.total_seq_lens, 1],
device=k.device, device=k.device,
dtype=torch.float32) dtype=torch.float32)
cp_gather_indexer_k_quant_cache( cp_gather_indexer_k_quant_cache(
kv_cache, kv_cache,
k_fp8, k_fp8,
k_scale, k_scale,
chunk.block_table, chunk.block_table,
chunk.cu_seq_lens, chunk.cu_seq_lens,
chunk.num_reqs, chunk.num_reqs,
) )
if not current_platform.is_rocm():
logits = fp8_mqa_logits( logits = fp8_mqa_logits(
q_fp8[chunk.token_start:chunk.token_end], q_fp8[chunk.token_start:chunk.token_end],
(k_fp8, k_scale), (k_fp8, k_scale),
...@@ -637,12 +640,18 @@ def sparse_attn_indexer( ...@@ -637,12 +640,18 @@ def sparse_attn_indexer(
) )
else: else:
logits = op.mqa_logits( logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end].half(), q_fp8[chunk.token_start:chunk.token_end],
(k_fp8.half(), k_scale), k,
weights[chunk.token_start:chunk.token_end], weights[chunk.token_start:chunk.token_end] if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else weights[chunk.token_start:chunk.token_end].to(torch.float32),
chunk.cu_seqlen_ks, chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke, chunk.cu_seqlen_ke,
) q_fp8[chunk.token_start:chunk.token_end].shape[0],
k.shape[0],
64,
128,
True,
)
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
dim=-1)[1] dim=-1)[1]
topk_indices -= chunk.cu_seqlen_ks[:, None] topk_indices -= chunk.cu_seqlen_ks[:, None]
...@@ -692,14 +701,15 @@ def sparse_attn_indexer( ...@@ -692,14 +701,15 @@ def sparse_attn_indexer(
) )
else: else:
logits = gemmopt.paged_mqa_logits( logits = gemmopt.paged_mqa_logits(
padded_q_fp8_decode_tokens.half(), padded_q_fp8_decode_tokens,
kv_cache.half(), kv_cache if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else kv_cache.to(torch.bfloat16),
weights[:num_padded_tokens], weights[:num_padded_tokens] if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else weights[:num_padded_tokens].to(torch.float32),
decode_metadata.seq_lens, decode_metadata.seq_lens,
decode_metadata.block_table, decode_metadata.block_table,
decode_metadata.schedule_metadata, decode_metadata.schedule_metadata,
max_context_len=max_model_len, max_model_len,
) )
# padded query len # padded query len
current_device = padded_q_fp8_decode_tokens.device current_device = padded_q_fp8_decode_tokens.device
padded_num_tokens = batch_size * next_n padded_num_tokens = batch_size * next_n
...@@ -753,12 +763,13 @@ def sparse_attn_indexer_fake( ...@@ -753,12 +763,13 @@ def sparse_attn_indexer_fake(
# profile run # profile run
# NOTE(Chen): create the max possible flattened_kv. So that # NOTE(Chen): create the max possible flattened_kv. So that
# profile_run can get correct memory usage. # profile_run can get correct memory usage.
_flattened_kv = torch.empty([total_seq_lens, head_dim + 4], if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
device=k.device, _flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
dtype=torch.uint8) device=k.device,
_k_fp8 = _flattened_kv[..., :head_dim].view( dtype=torch.uint8)
torch.float8_e4m3fn).contiguous() _k_fp8 = _flattened_kv[..., :head_dim].view(
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() torch.float8_e4m3fn).contiguous()
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
return topk_indices_buffer return topk_indices_buffer
...@@ -845,35 +856,54 @@ class Indexer(nn.Module): ...@@ -845,35 +856,54 @@ class Indexer(nn.Module):
k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1) k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
# we only quant q here since k quant is fused with cache insertion # we only quant q here since k quant is fused with cache insertion
q = q.view(-1, self.head_dim) if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
q_fp8, q_scale = per_token_group_quant_fp8(q, q = q.view(-1, self.head_dim)
self.quant_block_size, q_fp8, q_scale = per_token_group_quant_fp8(q,
column_major_scales=False, self.quant_block_size,
use_ue8m0=self.scale_fmt column_major_scales=False,
is not None) use_ue8m0=self.scale_fmt
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim) is not None)
q_scale = q_scale.view(-1, self.n_head, 1) q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1)
weights, _ = self.weights_proj(hidden_states) weights, _ = self.weights_proj(hidden_states)
weights = weights.unsqueeze( if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
-1) * q_scale * self.softmax_scale * self.n_head**-0.5 weights = weights.unsqueeze(
weights = weights.squeeze(-1) -1) * q_scale * self.softmax_scale * self.n_head**-0.5
weights = weights.squeeze(-1)
return torch.ops.vllm.sparse_attn_indexer(
hidden_states, if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
self.k_cache.prefix, return torch.ops.vllm.sparse_attn_indexer(
self.k_cache.kv_cache[0], hidden_states,
q_fp8, self.k_cache.prefix,
k, self.k_cache.kv_cache[0],
weights, q_fp8,
self.quant_block_size, k,
self.scale_fmt, weights,
self.topk_tokens, self.quant_block_size,
self.head_dim, self.scale_fmt,
self.max_model_len, self.topk_tokens,
self.max_total_seq_len, self.head_dim,
self.topk_indices_buffer, self.max_model_len,
) self.max_total_seq_len,
self.topk_indices_buffer,
)
else:
return torch.ops.vllm.sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
q,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
class DeepseekV2MLAAttention(nn.Module): class DeepseekV2MLAAttention(nn.Module):
...@@ -1583,4 +1613,4 @@ def get_spec_layer_idx_from_weight_name(config: Union[DeepseekV2Config, ...@@ -1583,4 +1613,4 @@ def get_spec_layer_idx_from_weight_name(config: Union[DeepseekV2Config,
for i in range(config.num_nextn_predict_layers): for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(f"model.layers.{layer_idx+i}."): if weight_name.startswith(f"model.layers.{layer_idx+i}."):
return layer_idx + i return layer_idx + i
return None return None
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment