Unverified commit 9ff9fa7f, authored by Trevor Morris and committed by GitHub

Fuse wk and weight_proj in Indexer for DeepSeekV3.2-FP4 (#12094)

parent 7ed8ba05
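
For context on the change below: fusing `wk` and `weights_proj` amounts to stacking the two projection matrices along the output dimension, so a single GEMM followed by a split reproduces the two separate outputs. A minimal standalone sketch of that equivalence (plain `torch.nn.Linear` with made-up sizes; the real layers are SGLang `ReplicatedLinear` modules and the real dimensions come from the model config):

```python
import torch

torch.manual_seed(0)
hidden_size, head_dim, n_heads = 32, 16, 4  # illustrative sizes only
x = torch.randn(2, hidden_size)

# Separate projections, as before the change.
wk = torch.nn.Linear(hidden_size, head_dim, bias=False)
weights_proj = torch.nn.Linear(hidden_size, n_heads, bias=False)

# Fused projection: concatenate the two weight matrices along the output (row)
# dimension, mirroring the torch.cat(..., dim=0) done at weight-loading time.
fused = torch.nn.Linear(hidden_size, head_dim + n_heads, bias=False)
with torch.no_grad():
    fused.weight.copy_(torch.cat([wk.weight, weights_proj.weight], dim=0))

# One GEMM, then split the output the same way the Indexer does.
key, weights = fused(x).split([head_dim, n_heads], dim=-1)

assert torch.allclose(key, wk(x), atol=1e-6)
assert torch.allclose(weights, weights_proj(x), atol=1e-6)
```

The loader-side change below performs exactly this row-wise `torch.cat(..., dim=0)` on the checkpoint tensors before handing the result to the fused parameter.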
@@ -119,6 +119,7 @@ class Indexer(CustomOp):
         prefix: str = "",
         quant_config: Optional[QuantizationConfig] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
+        fuse_wk_and_weights_proj: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -129,6 +130,7 @@ class Indexer(CustomOp):
         self.q_lora_rank = q_lora_rank
         self.layer_id = layer_id
         self.alt_stream = alt_stream
+        self.fuse_wk_and_weights_proj = fuse_wk_and_weights_proj
         if is_cuda():
             self.sm_count = deep_gemm.get_num_sms()
             self.half_device_sm_count = align(self.sm_count // 2, 8)
@@ -140,6 +142,14 @@ class Indexer(CustomOp):
             quant_config=quant_config,
             prefix=add_prefix("wq_b", prefix),
         )
+        if self.fuse_wk_and_weights_proj:
+            self.fused_wk_and_weights_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.head_dim + self.n_heads,
+                bias=False,
+                prefix=add_prefix("fused_wk_and_weights_proj", prefix),
+            )
+        else:
             self.wk = ReplicatedLinear(
                 self.hidden_size,
                 self.head_dim,
@@ -147,7 +157,6 @@ class Indexer(CustomOp):
                 quant_config=quant_config,
                 prefix=add_prefix("wk", prefix),
             )
-        self.k_norm = V32LayerNorm(self.head_dim)
             # NOTE: weight_proj is not quantized
             self.weights_proj = ReplicatedLinear(
                 self.hidden_size,
@@ -155,6 +164,7 @@ class Indexer(CustomOp):
                 bias=False,
                 prefix=add_prefix("weights_proj", prefix),
             )
+        self.k_norm = V32LayerNorm(self.head_dim)
         self.rotary_emb = get_rope_wrapper(
             rope_head_dim,
             rotary_dim=rope_head_dim,
@@ -169,8 +179,7 @@ class Indexer(CustomOp):
         self.softmax_scale = self.head_dim**-0.5

     @torch.compile(dynamic=True)
-    def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
-        weights, _ = self.weights_proj(x)
+    def _get_logits_head_gate(self, weights: torch.Tensor, q_scale: torch.Tensor):
         weights = weights * self.n_heads**-0.5
         weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
         return weights
@@ -182,7 +191,7 @@ class Indexer(CustomOp):
         positions: torch.Tensor,
         enable_dual_stream: bool,
     ):
+        weights = None
         if enable_dual_stream:
             current_stream = torch.cuda.current_stream()
             self.alt_stream.wait_stream(current_stream)
@@ -199,6 +208,11 @@ class Indexer(CustomOp):
             )
             with torch.cuda.stream(self.alt_stream):
                 # TODO we should also put DeepGEMM half SM here?
+                if self.fuse_wk_and_weights_proj:
+                    key, weights = self.fused_wk_and_weights_proj(x)[0].split(
+                        [self.head_dim, self.n_heads], dim=-1
+                    )
+                else:
                     key, _ = self.wk(x)
                 key = self.k_norm(key)
@@ -217,6 +231,11 @@ class Indexer(CustomOp):
                 query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
             )
+            if self.fuse_wk_and_weights_proj:
+                key, weights = self.fused_wk_and_weights_proj(x)[0].split(
+                    [self.head_dim, self.n_heads], dim=-1
+                )
+            else:
                 key, _ = self.wk(x)
             key = self.k_norm(key)
             k_rope, _ = torch.split(
@@ -240,7 +259,7 @@ class Indexer(CustomOp):
         query = rotate_activation(query)
         key = rotate_activation(key)
-        return query, key
+        return query, key, weights

     def _get_topk_paged(
         self,
@@ -490,7 +509,9 @@ class Indexer(CustomOp):
         if metadata is None:
             return None

-        query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
+        query, key, weights = self._get_q_k_bf16(
+            q_lora, x, positions, enable_dual_stream
+        )

         if enable_dual_stream:
             current_stream = torch.cuda.current_stream()
@@ -517,7 +538,9 @@ class Indexer(CustomOp):
             index_k_scale=k_scale,
         )

-        weights = self._get_logits_head_gate(x, q_scale)
+        if not self.fuse_wk_and_weights_proj:
+            weights, _ = self.weights_proj(x)
+        weights = self._get_logits_head_gate(weights, q_scale)

         if is_cuda():
             assert forward_batch.seq_lens_cpu is not None
......
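
A note on the `_get_logits_head_gate` hunks above: the function no longer applies `weights_proj` itself; the caller now passes in the raw per-head weights (from `weights_proj`, or from the fused split when fusion is enabled), and the compiled function only performs the scaling. A rough standalone restatement of that scaling, with illustrative tensor shapes:

```python
import torch

def logits_head_gate(weights: torch.Tensor, q_scale: torch.Tensor,
                     n_heads: int, head_dim: int) -> torch.Tensor:
    # `weights` is the raw weights_proj (or fused-split) output, shape [tokens, n_heads].
    # Previously the projection happened inside this function; now the caller supplies it.
    softmax_scale = head_dim ** -0.5
    weights = weights * n_heads ** -0.5
    return weights.unsqueeze(-1) * q_scale * softmax_scale

# Illustrative shapes only: q_scale broadcast as [tokens, n_heads, 1].
tokens, n_heads, head_dim = 3, 4, 16
gate = logits_head_gate(torch.randn(tokens, n_heads),
                        torch.rand(tokens, n_heads, 1),
                        n_heads, head_dim)
print(gate.shape)  # torch.Size([3, 4, 1])
```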
@@ -224,6 +224,17 @@ def add_forward_absorb_core_attention_backend(backend_name):
     logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")


+def is_nsa_indexer_wk_and_weights_proj_fused(config, quant_config):
+    """
+    The NSA Indexer's wk and weights_proj can be fused in the FP4 model because both are kept in BF16.
+    """
+    return (
+        is_deepseek_nsa(config)
+        and quant_config is not None
+        and quant_config.get_name() == "modelopt_fp4"
+    )
+
+
 class AttnForwardMethod(IntEnum):
     # Use multi-head attention
     MHA = auto()
@@ -1143,6 +1154,9 @@ class DeepseekV2AttentionMLA(nn.Module):
                 quant_config=quant_config,
                 layer_id=layer_id,
                 alt_stream=alt_stream,
+                fuse_wk_and_weights_proj=is_nsa_indexer_wk_and_weights_proj_fused(
+                    config, quant_config
+                ),
             )

         self.kv_b_proj = ColumnParallelLinear(
@@ -3413,6 +3427,10 @@ class DeepseekV2ForCausalLM(nn.Module):
             self.config.q_lora_rank is not None
         )
         cached_a_proj = {} if fuse_qkv_a_proj else None
+        fuse_wk_and_weights_proj = is_nsa_indexer_wk_and_weights_proj_fused(
+            self.config, self.quant_config
+        )
+        cached_wk_and_weights_proj = {} if fuse_wk_and_weights_proj else None

         if is_nextn:
             nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
@@ -3584,6 +3602,53 @@ class DeepseekV2ForCausalLM(nn.Module):
                         )
                         cached_a_proj.pop(q_a_proj_name)
                         cached_a_proj.pop(kv_a_proj_name)
+                elif fuse_wk_and_weights_proj and (
+                    "wk" in name or "weights_proj" in name
+                ):
+                    cached_wk_and_weights_proj[name] = loaded_weight
+                    wk_name = (
+                        name
+                        if "wk" in name
+                        else name.replace("weights_proj", "wk")
+                    )
+                    weights_proj_name = (
+                        name
+                        if "weights_proj" in name
+                        else name.replace("wk", "weights_proj")
+                    )
+                    # When both wk and weights_proj have been cached, load the fused weight into the parameter
+                    if (
+                        wk_name in cached_wk_and_weights_proj
+                        and weights_proj_name in cached_wk_and_weights_proj
+                    ):
+                        wk_weight = cached_wk_and_weights_proj[wk_name]
+                        weights_proj_weight = cached_wk_and_weights_proj[
+                            weights_proj_name
+                        ]
+                        # TODO: dequantize wk for FP8
+                        assert wk_weight.dtype == weights_proj_weight.dtype
+                        fused_weight = torch.cat(
+                            [wk_weight, weights_proj_weight], dim=0
+                        )
+                        param_name = (
+                            name.replace("wk", "fused_wk_and_weights_proj")
+                            if "wk" in name
+                            else name.replace(
+                                "weights_proj",
+                                "fused_wk_and_weights_proj",
+                            )
+                        )
+                        param = params_dict[param_name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        futures.append(
+                            executor.submit(weight_loader, param, fused_weight)
+                        )
+                        cached_wk_and_weights_proj.pop(wk_name)
+                        cached_wk_and_weights_proj.pop(weights_proj_name)
                 else:
                     if (
                         "k_scale" in name or "v_scale" in name
......
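
The loader hunk above pairs `wk` and `weights_proj` checkpoint tensors and only dispatches a load once both halves of a pair have been seen. A simplified sketch of that pairing logic, stripped of SGLang's executor and `weight_loader` plumbing (the helper name `fuse_indexer_pair` and the toy shapes are illustrative):

```python
import torch

def fuse_indexer_pair(name, loaded_weight, cached):
    """Cache wk / weights_proj tensors; once both halves of a pair are present,
    return (fused_param_name, fused_weight), otherwise None."""
    cached[name] = loaded_weight
    wk_name = name if "wk" in name else name.replace("weights_proj", "wk")
    wp_name = name if "weights_proj" in name else name.replace("wk", "weights_proj")
    if wk_name in cached and wp_name in cached:
        # Row-wise concat matches fused_wk_and_weights_proj's
        # [head_dim + n_heads, hidden_size] weight layout.
        fused = torch.cat([cached.pop(wk_name), cached.pop(wp_name)], dim=0)
        return wk_name.replace("wk", "fused_wk_and_weights_proj"), fused
    return None

cached = {}
print(fuse_indexer_pair("layers.0.indexer.wk.weight", torch.zeros(16, 32), cached))
# -> None (still waiting for the matching weights_proj tensor)
name, w = fuse_indexer_pair("layers.0.indexer.weights_proj.weight", torch.zeros(4, 32), cached)
print(name, tuple(w.shape))
# -> layers.0.indexer.fused_wk_and_weights_proj.weight (20, 32)
```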