Unverified Commit 9ff9fa7f authored by Trevor Morris, committed by GitHub

Fuse wk and weight_proj in Indexer for DeepSeekV3.2-FP4 (#12094)
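The idea behind the fusion: `wk` and `weights_proj` both consume the same hidden-state activation and both stay in BF16 under the FP4 (ModelOpt) quant config, so their weight matrices can be stacked and computed in a single GEMM whose output is split back into the key projection and the head-gate logits. A minimal, hypothetical sketch of that pattern is below; it is not the SGLang code, and the layer sizes are illustrative only.

```python
import torch
import torch.nn as nn

# Illustrative sizes only (not taken from the model config).
hidden_size, head_dim, n_heads = 1024, 128, 64

# Unfused path: two separate GEMMs over the same input.
wk = nn.Linear(hidden_size, head_dim, bias=False)
weights_proj = nn.Linear(hidden_size, n_heads, bias=False)

# Fused path: stack the weight matrices so one GEMM produces both outputs.
fused = nn.Linear(hidden_size, head_dim + n_heads, bias=False)
with torch.no_grad():
    fused.weight.copy_(torch.cat([wk.weight, weights_proj.weight], dim=0))

x = torch.randn(4, hidden_size)
key, gate = fused(x).split([head_dim, n_heads], dim=-1)

# The fused result matches the two separate projections.
assert torch.allclose(key, wk(x), atol=1e-4)
assert torch.allclose(gate, weights_proj(x), atol=1e-4)
```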

parent 7ed8ba05
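On the loader side the checkpoint still stores `wk` and `weights_proj` as separate tensors, so the two halves have to be cached until the pair is complete and then concatenated into the fused parameter. A hedged sketch of that pattern follows, with a hypothetical helper name and simplified name matching; the real loader in the diff below also handles sharding and submits the load through an executor.

```python
import torch

def fuse_indexer_weights(named_weights):
    """Hypothetical helper: merge wk/weights_proj pairs into one fused tensor."""
    cache = {}
    for name, weight in named_weights:
        if "wk" in name or "weights_proj" in name:
            cache[name] = weight
            wk_name = name if "wk" in name else name.replace("weights_proj", "wk")
            wp_name = name if "weights_proj" in name else name.replace("wk", "weights_proj")
            # Emit the fused tensor only once both halves have arrived.
            if wk_name in cache and wp_name in cache:
                fused = torch.cat([cache.pop(wk_name), cache.pop(wp_name)], dim=0)
                yield wk_name.replace("wk", "fused_wk_and_weights_proj"), fused
        else:
            yield name, weight
```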
@@ -119,6 +119,7 @@ class Indexer(CustomOp):
prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
alt_stream: Optional[torch.cuda.Stream] = None,
fuse_wk_and_weights_proj: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
@@ -129,6 +130,7 @@ class Indexer(CustomOp):
self.q_lora_rank = q_lora_rank
self.layer_id = layer_id
self.alt_stream = alt_stream
self.fuse_wk_and_weights_proj = fuse_wk_and_weights_proj
if is_cuda():
self.sm_count = deep_gemm.get_num_sms()
self.half_device_sm_count = align(self.sm_count // 2, 8)
@@ -140,21 +142,29 @@ class Indexer(CustomOp):
quant_config=quant_config,
prefix=add_prefix("wq_b", prefix),
)
self.wk = ReplicatedLinear(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("wk", prefix),
)
if self.fuse_wk_and_weights_proj:
self.fused_wk_and_weights_proj = ReplicatedLinear(
self.hidden_size,
self.head_dim + self.n_heads,
bias=False,
prefix=add_prefix("fused_wk_and_weights_proj", prefix),
)
else:
self.wk = ReplicatedLinear(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("wk", prefix),
)
# NOTE: weights_proj is not quantized
self.weights_proj = ReplicatedLinear(
self.hidden_size,
self.n_heads,
bias=False,
prefix=add_prefix("weights_proj", prefix),
)
self.k_norm = V32LayerNorm(self.head_dim)
# NOTE: weights_proj is not quantized
self.weights_proj = ReplicatedLinear(
self.hidden_size,
self.n_heads,
bias=False,
prefix=add_prefix("weights_proj", prefix),
)
self.rotary_emb = get_rope_wrapper(
rope_head_dim,
rotary_dim=rope_head_dim,
@@ -169,8 +179,7 @@ class Indexer(CustomOp):
self.softmax_scale = self.head_dim**-0.5
@torch.compile(dynamic=True)
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
weights, _ = self.weights_proj(x)
def _get_logits_head_gate(self, weights: torch.Tensor, q_scale: torch.Tensor):
weights = weights * self.n_heads**-0.5
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
return weights
@@ -182,7 +191,7 @@ class Indexer(CustomOp):
positions: torch.Tensor,
enable_dual_stream: bool,
):
weights = None
if enable_dual_stream:
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
@@ -199,7 +208,12 @@ class Indexer(CustomOp):
)
with torch.cuda.stream(self.alt_stream):
# TODO we should also put DeepGEMM half SM here?
key, _ = self.wk(x)
if self.fuse_wk_and_weights_proj:
key, weights = self.fused_wk_and_weights_proj(x)[0].split(
[self.head_dim, self.n_heads], dim=-1
)
else:
key, _ = self.wk(x)
key = self.k_norm(key)
k_rope, _ = torch.split(
@@ -217,7 +231,12 @@ class Indexer(CustomOp):
query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
)
key, _ = self.wk(x)
if self.fuse_wk_and_weights_proj:
key, weights = self.fused_wk_and_weights_proj(x)[0].split(
[self.head_dim, self.n_heads], dim=-1
)
else:
key, _ = self.wk(x)
key = self.k_norm(key)
k_rope, _ = torch.split(
key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
@@ -240,7 +259,7 @@ class Indexer(CustomOp):
query = rotate_activation(query)
key = rotate_activation(key)
return query, key
return query, key, weights
def _get_topk_paged(
self,
@@ -490,7 +509,9 @@ class Indexer(CustomOp):
if metadata is None:
return None
query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
query, key, weights = self._get_q_k_bf16(
q_lora, x, positions, enable_dual_stream
)
if enable_dual_stream:
current_stream = torch.cuda.current_stream()
@@ -517,7 +538,9 @@ class Indexer(CustomOp):
index_k_scale=k_scale,
)
weights = self._get_logits_head_gate(x, q_scale)
if not self.fuse_wk_and_weights_proj:
weights, _ = self.weights_proj(x)
weights = self._get_logits_head_gate(weights, q_scale)
if is_cuda():
assert forward_batch.seq_lens_cpu is not None
......
@@ -224,6 +224,17 @@ def add_forward_absorb_core_attention_backend(backend_name):
logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")
def is_nsa_indexer_wk_and_weights_proj_fused(config, quant_config):
"""
The NSA Indexer's wk and weights_proj can be fused in the FP4 model because both are in BF16
"""
return (
is_deepseek_nsa(config)
and quant_config is not None
and quant_config.get_name() == "modelopt_fp4"
)
class AttnForwardMethod(IntEnum):
# Use multi-head attention
MHA = auto()
@@ -1143,6 +1154,9 @@ class DeepseekV2AttentionMLA(nn.Module):
quant_config=quant_config,
layer_id=layer_id,
alt_stream=alt_stream,
fuse_wk_and_weights_proj=is_nsa_indexer_wk_and_weights_proj_fused(
config, quant_config
),
)
self.kv_b_proj = ColumnParallelLinear(
@@ -3413,6 +3427,10 @@ class DeepseekV2ForCausalLM(nn.Module):
self.config.q_lora_rank is not None
)
cached_a_proj = {} if fuse_qkv_a_proj else None
fuse_wk_and_weights_proj = is_nsa_indexer_wk_and_weights_proj_fused(
self.config, self.quant_config
)
cached_wk_and_weights_proj = {} if fuse_wk_and_weights_proj else None
if is_nextn:
nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
@@ -3584,6 +3602,53 @@ class DeepseekV2ForCausalLM(nn.Module):
)
cached_a_proj.pop(q_a_proj_name)
cached_a_proj.pop(kv_a_proj_name)
elif fuse_wk_and_weights_proj and (
"wk" in name or "weights_proj" in name
):
cached_wk_and_weights_proj[name] = loaded_weight
wk_name = (
name
if "wk" in name
else name.replace("weights_proj", "wk")
)
weights_proj_name = (
name
if "weights_proj" in name
else name.replace("wk", "weights_proj")
)
# When both wk and weights_proj have been cached, load the fused weight into the parameter
if (
wk_name in cached_wk_and_weights_proj
and weights_proj_name in cached_wk_and_weights_proj
):
wk_weight = cached_wk_and_weights_proj[wk_name]
weights_proj_weight = cached_wk_and_weights_proj[
weights_proj_name
]
# TODO: dequantize wk for FP8
assert wk_weight.dtype == weights_proj_weight.dtype
fused_weight = torch.cat(
[wk_weight, weights_proj_weight], dim=0
)
param_name = (
name.replace("wk", "fused_wk_and_weights_proj")
if "wk" in name
else name.replace(
"weights_proj",
"fused_wk_and_weights_proj",
)
)
param = params_dict[param_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
futures.append(
executor.submit(weight_loader, param, fused_weight)
)
cached_wk_and_weights_proj.pop(wk_name)
cached_wk_and_weights_proj.pop(weights_proj_name)
else:
if (
"k_scale" in name or "v_scale" in name
......