Unverified Commit ac3dac54 authored by Benjamin Chislett's avatar Benjamin Chislett Committed by GitHub
Browse files

[Bugfix][Perf] Indexer upcast WK to BF16 for fusion (#38928)


Signed-off-by: default avatarBenjamin Chislett <bchislett@nvidia.com>
parent 39ac6404
...@@ -30,6 +30,7 @@ from .deepseek_v2 import ( ...@@ -30,6 +30,7 @@ from .deepseek_v2 import (
DeepseekV2DecoderLayer, DeepseekV2DecoderLayer,
DeepseekV2MixtureOfExperts, DeepseekV2MixtureOfExperts,
DeepseekV2MoE, DeepseekV2MoE,
_try_load_fp8_indexer_wk,
get_spec_layer_idx_from_weight_name, get_spec_layer_idx_from_weight_name,
) )
from .utils import maybe_prefix from .utils import maybe_prefix
...@@ -190,10 +191,6 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -190,10 +191,6 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
) )
# Set MoE hyperparameters # Set MoE hyperparameters
self.set_moe_parameters() self.set_moe_parameters()
self.is_fp4_ckpt = (
self.quant_config is not None
and self.quant_config.get_name() == "modelopt_fp4"
)
def set_moe_parameters(self): def set_moe_parameters(self):
self.expert_weights = [] self.expert_weights = []
...@@ -248,13 +245,12 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -248,13 +245,12 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
] ]
if self.is_fp4_ckpt: # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj) indexer_fused_mapping = [
indexer_fused_mapping = [ ("wk_weights_proj", "wk", 0),
("wk_weights_proj", "wk", 0), ("wk_weights_proj", "weights_proj", 1),
("wk_weights_proj", "weights_proj", 1), ]
] stacked_params_mapping.extend(indexer_fused_mapping)
stacked_params_mapping.extend(indexer_fused_mapping)
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
self, self,
...@@ -271,6 +267,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -271,6 +267,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
_pending_wk_fp8: dict = {} # FP8 indexer wk dequant buffer
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -281,6 +278,12 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -281,6 +278,12 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name) rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
) )
name = self._rewrite_spec_layer_name(spec_layer, name) name = self._rewrite_spec_layer_name(spec_layer, name)
if _try_load_fp8_indexer_wk(
name, loaded_weight, _pending_wk_fp8, params_dict, loaded_params
):
continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below). # Skip non-stacked layers and experts (experts handled below).
if weight_name not in name: if weight_name not in name:
......
...@@ -66,6 +66,10 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -66,6 +66,10 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, per_token_group_quant_fp8,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
scaled_dequantize,
)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sparse_attn_indexer import ( from vllm.model_executor.layers.sparse_attn_indexer import (
SparseAttnIndexer, SparseAttnIndexer,
...@@ -628,10 +632,6 @@ class Indexer(nn.Module): ...@@ -628,10 +632,6 @@ class Indexer(nn.Module):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.is_fp4_ckpt = (
self.quant_config is not None
and self.quant_config.get_name() == "modelopt_fp4"
)
# self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"] # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
self.topk_tokens = config.index_topk self.topk_tokens = config.index_topk
self.n_head = config.index_n_heads # 64 self.n_head = config.index_n_heads # 64
...@@ -646,36 +646,16 @@ class Indexer(nn.Module): ...@@ -646,36 +646,16 @@ class Indexer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.wq_b", prefix=f"{prefix}.wq_b",
) )
if self.is_fp4_ckpt: # Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
# Fused wk + weights_proj: single GEMM producing [head_dim + n_head]. # FP8 wk weights are upcasted to BF16 during loading to maintain fusion.
# weights_proj does not get quantized, self.wk_weights_proj = MergedColumnParallelLinear(
# so we run both with quant_config=None hidden_size,
# wk may be upcasted from the default quant; [self.head_dim, self.n_head],
# experiments show fusion is always faster unless WK proj is in FP4, bias=False,
# which is not the case for all known quants. quant_config=None,
self.wk_weights_proj = MergedColumnParallelLinear( disable_tp=True,
hidden_size, prefix=f"{prefix}.wk_weights_proj",
[self.head_dim, self.n_head], )
bias=False,
quant_config=None,
disable_tp=True,
prefix=f"{prefix}.wk_weights_proj",
)
else:
self.wk = ReplicatedLinear(
hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk",
)
self.weights_proj = ReplicatedLinear(
hidden_size,
self.n_head,
bias=False,
quant_config=None,
prefix=f"{prefix}.weights_proj",
)
self.k_norm = LayerNorm(self.head_dim, eps=1e-6) self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.softmax_scale = self.head_dim**-0.5 self.softmax_scale = self.head_dim**-0.5
...@@ -716,14 +696,10 @@ class Indexer(nn.Module): ...@@ -716,14 +696,10 @@ class Indexer(nn.Module):
q_pe, q_nope = torch.split( q_pe, q_nope = torch.split(
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
) )
if self.is_fp4_ckpt: # Fused wk + weights_proj: one GEMM, then split
# Fused wk + weights_proj: one GEMM, then split kw, _ = self.wk_weights_proj(hidden_states)
kw, _ = self.wk_weights_proj(hidden_states) k = kw[:, : self.head_dim]
k = kw[:, : self.head_dim] weights = kw[:, self.head_dim :]
weights = kw[:, self.head_dim :]
else:
k, _ = self.wk(hidden_states)
weights, _ = self.weights_proj(hidden_states)
k = self.k_norm(k) k = self.k_norm(k)
k_pe, k_nope = torch.split( k_pe, k_nope = torch.split(
...@@ -761,6 +737,46 @@ class Indexer(nn.Module): ...@@ -761,6 +737,46 @@ class Indexer(nn.Module):
return self.indexer_op(hidden_states, q_fp8, k, weights) return self.indexer_op(hidden_states, q_fp8, k, weights)
def _try_load_fp8_indexer_wk(name, tensor, buf, params_dict, loaded_params):
"""
We fuse the WK and weights_proj projections, but in some checkpoints WK is stored
in FP8 with a separate weight_scale_inv, while weights_proj is stored in BF16.
Upcasting to BF16 during loading enables the fusion. This function loads the FP8 WK
weights and scale, and when both are available, dequantizes to BF16 and stores into
the fused wk_weights_proj.weight parameter.
"""
if "indexer.wk." not in name or "wk_weights" in name:
return False # Weight is not an isolated WK weight for the indexer, ignore.
is_weight = name.endswith(".weight") and tensor.dtype == torch.float8_e4m3fn
is_scale = "weight_scale_inv" in name
if not is_weight and not is_scale:
return False # WK is not in FP8 format, ignore.
# Buffer this tensor (weight or scale) until both have arrived.
layer_prefix = name.rsplit(".wk.", 1)[0] # e.g. "model.layers.0.self_attn.indexer"
entry = buf.setdefault(layer_prefix, {})
entry["weight" if is_weight else "scale"] = tensor
if "weight" not in entry or "scale" not in entry:
return True # still waiting for the other param
# We have both weight and scale: dequantize FP8 to BF16.
weight_fp8, scale_inv = entry["weight"], entry["scale"]
del buf[layer_prefix]
block_size = weight_fp8.shape[1] // scale_inv.shape[1]
weight_bf16 = scaled_dequantize(
weight_fp8,
scale_inv,
group_shape=GroupShape(block_size, block_size),
out_dtype=torch.bfloat16,
)
# Load the dequantized weight into shard 0 of the fused buffer.
fused_name = f"{layer_prefix}.wk_weights_proj.weight"
param = params_dict[fused_name]
param.weight_loader(param, weight_bf16, 0)
loaded_params.add(fused_name)
return True
def _min_latency_fused_qkv_a_proj_impl( def _min_latency_fused_qkv_a_proj_impl(
input_: torch.Tensor, input_: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
...@@ -1344,10 +1360,6 @@ class DeepseekV2ForCausalLM( ...@@ -1344,10 +1360,6 @@ class DeepseekV2ForCausalLM(
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.is_fp4_ckpt = (
self.quant_config is not None
and self.quant_config.get_name() == "modelopt_fp4"
)
qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
...@@ -1473,13 +1485,13 @@ class DeepseekV2ForCausalLM( ...@@ -1473,13 +1485,13 @@ class DeepseekV2ForCausalLM(
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
if self.is_fp4_ckpt: # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj) _pending_wk_fp8: dict = {} # When WK is in FP8, we dequant to BF16 for fusion
indexer_fused_mapping = [ indexer_fused_mapping = [
("wk_weights_proj", "wk", 0), ("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1), ("wk_weights_proj", "weights_proj", 1),
] ]
stacked_params_mapping.extend(indexer_fused_mapping) stacked_params_mapping.extend(indexer_fused_mapping)
if self.use_mha: if self.use_mha:
stacked_params_mapping.extend(mha_params_mapping) stacked_params_mapping.extend(mha_params_mapping)
...@@ -1516,6 +1528,11 @@ class DeepseekV2ForCausalLM( ...@@ -1516,6 +1528,11 @@ class DeepseekV2ForCausalLM(
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name) rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
) )
if _try_load_fp8_indexer_wk(
name, loaded_weight, _pending_wk_fp8, params_dict, loaded_params
):
continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below). # Skip non-stacked layers and experts (experts handled below).
if weight_name not in name: if weight_name not in name:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment