Commit f441aca2 authored by zhuwenwen
Browse files

update mqa_logits and paged_mqa_logits

parent cc7715fd
......@@ -83,6 +83,7 @@ from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON
if current_platform.is_rocm():
import lightop
from lightop import op, gemmopt
else:
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
......@@ -601,33 +602,35 @@ def sparse_attn_indexer(
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
ops.indexer_k_quant_and_cache(
k,
kv_cache,
slot_mapping,
quant_block_size,
scale_fmt,
)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
ops.indexer_k_quant_and_cache(
k,
kv_cache,
slot_mapping,
quant_block_size,
scale_fmt,
)
topk_indices_buffer[:hidden_states.shape[0]] = -1
if has_prefill:
prefill_metadata = attn_metadata.prefill
for chunk in prefill_metadata.chunks:
k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
device=k.device,
dtype=torch.float8_e4m3fn)
k_scale = torch.empty([chunk.total_seq_lens, 1],
device=k.device,
dtype=torch.float32)
cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
chunk.num_reqs,
)
if not current_platform.is_rocm():
k_scale = torch.empty([chunk.total_seq_lens, 1],
device=k.device,
dtype=torch.float32)
cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
chunk.num_reqs,
)
logits = fp8_mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
(k_fp8, k_scale),
......@@ -637,12 +640,18 @@ def sparse_attn_indexer(
)
else:
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end].half(),
(k_fp8.half(), k_scale),
weights[chunk.token_start:chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
q_fp8[chunk.token_start:chunk.token_end],
k,
weights[chunk.token_start:chunk.token_end] if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else weights[chunk.token_start:chunk.token_end].to(torch.float32),
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0],
k.shape[0],
64,
128,
True,
)
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
dim=-1)[1]
topk_indices -= chunk.cu_seqlen_ks[:, None]
......@@ -692,14 +701,15 @@ def sparse_attn_indexer(
)
else:
logits = gemmopt.paged_mqa_logits(
padded_q_fp8_decode_tokens.half(),
kv_cache.half(),
weights[:num_padded_tokens],
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_context_len=max_model_len,
padded_q_fp8_decode_tokens,
kv_cache if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else kv_cache.to(torch.bfloat16),
weights[:num_padded_tokens] if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else weights[:num_padded_tokens].to(torch.float32),
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len,
)
# padded query len
current_device = padded_q_fp8_decode_tokens.device
padded_num_tokens = batch_size * next_n
......@@ -753,12 +763,13 @@ def sparse_attn_indexer_fake(
# profile run
# NOTE(Chen): create the max possible flattened_kv. So that
# profile_run can get correct memory usage.
_flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
device=k.device,
dtype=torch.uint8)
_k_fp8 = _flattened_kv[..., :head_dim].view(
torch.float8_e4m3fn).contiguous()
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
_flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
device=k.device,
dtype=torch.uint8)
_k_fp8 = _flattened_kv[..., :head_dim].view(
torch.float8_e4m3fn).contiguous()
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
return topk_indices_buffer
......@@ -845,35 +856,54 @@ class Indexer(nn.Module):
k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
# we only quant q here since k quant is fused with cache insertion
q = q.view(-1, self.head_dim)
q_fp8, q_scale = per_token_group_quant_fp8(q,
self.quant_block_size,
column_major_scales=False,
use_ue8m0=self.scale_fmt
is not None)
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
q = q.view(-1, self.head_dim)
q_fp8, q_scale = per_token_group_quant_fp8(q,
self.quant_block_size,
column_major_scales=False,
use_ue8m0=self.scale_fmt
is not None)
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1)
weights, _ = self.weights_proj(hidden_states)
weights = weights.unsqueeze(
-1) * q_scale * self.softmax_scale * self.n_head**-0.5
weights = weights.squeeze(-1)
return torch.ops.vllm.sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
q_fp8,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
weights = weights.unsqueeze(
-1) * q_scale * self.softmax_scale * self.n_head**-0.5
weights = weights.squeeze(-1)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return torch.ops.vllm.sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
q_fp8,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
else:
return torch.ops.vllm.sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
q,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
class DeepseekV2MLAAttention(nn.Module):
......@@ -1583,4 +1613,4 @@ def get_spec_layer_idx_from_weight_name(config: Union[DeepseekV2Config,
for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(f"model.layers.{layer_idx+i}."):
return layer_idx + i
return None
return None
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment