Unverified commit b6dd4bcb authored by cao1zhg, committed by GitHub
Browse files

feat: update support for qwen3next model (#10466)

parent b2435be6
...@@ -86,8 +86,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( ...@@ -86,8 +86,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
b_g = tl.load(p_g).to(tl.float32) b_g = tl.load(p_g).to(tl.float32)
if USE_QK_L2NORM_IN_KERNEL: if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6) b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6) b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
b_q = b_q * scale b_q = b_q * scale
# [BK, BV] # [BK, BV]
b_h *= exp(b_g) b_h *= exp(b_g)
...@@ -411,8 +411,8 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel( ...@@ -411,8 +411,8 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
b_g = tl.load(p_g).to(tl.float32) b_g = tl.load(p_g).to(tl.float32)
if USE_QK_L2NORM_IN_KERNEL: if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6) b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6) b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
b_q = b_q * scale b_q = b_q * scale
# [BK, BV] # [BK, BV]
b_h *= exp(b_g) b_h *= exp(b_g)
......
...@@ -119,8 +119,8 @@ def fused_sigmoid_gating_delta_rule_update_kernel( ...@@ -119,8 +119,8 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
# Apply L2 normalization if enabled # Apply L2 normalization if enabled
if USE_QK_L2NORM_IN_KERNEL: if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6) b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6) b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
b_q = b_q * scale b_q = b_q * scale
......
...@@ -239,6 +239,7 @@ class Qwen3GatedDeltaNet(nn.Module): ...@@ -239,6 +239,7 @@ class Qwen3GatedDeltaNet(nn.Module):
self, self,
config: Qwen3NextConfig, config: Qwen3NextConfig,
layer_id: int, layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
alt_stream: Optional[torch.cuda.Stream] = None, alt_stream: Optional[torch.cuda.Stream] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -278,6 +279,7 @@ class Qwen3GatedDeltaNet(nn.Module): ...@@ -278,6 +279,7 @@ class Qwen3GatedDeltaNet(nn.Module):
input_size=self.hidden_size, input_size=self.hidden_size,
output_size=projection_size_qkvz, output_size=projection_size_qkvz,
bias=False, bias=False,
quant_config=quant_config,
tp_rank=self.attn_tp_rank, tp_rank=self.attn_tp_rank,
tp_size=self.attn_tp_size, tp_size=self.attn_tp_size,
) )
...@@ -285,6 +287,7 @@ class Qwen3GatedDeltaNet(nn.Module): ...@@ -285,6 +287,7 @@ class Qwen3GatedDeltaNet(nn.Module):
input_size=self.hidden_size, input_size=self.hidden_size,
output_size=projection_size_ba, output_size=projection_size_ba,
bias=False, bias=False,
quant_config=None,
tp_rank=self.attn_tp_rank, tp_rank=self.attn_tp_rank,
tp_size=self.attn_tp_size, tp_size=self.attn_tp_size,
) )
...@@ -336,6 +339,7 @@ class Qwen3GatedDeltaNet(nn.Module): ...@@ -336,6 +339,7 @@ class Qwen3GatedDeltaNet(nn.Module):
self.value_dim, self.value_dim,
self.hidden_size, self.hidden_size,
bias=False, bias=False,
quant_config=quant_config,
input_is_parallel=True, input_is_parallel=True,
reduce_results=False, reduce_results=False,
tp_rank=self.attn_tp_rank, tp_rank=self.attn_tp_rank,
...@@ -493,7 +497,7 @@ class Qwen3HybridLinearDecoderLayer(nn.Module): ...@@ -493,7 +497,7 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, alt_stream) self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, quant_config, alt_stream)
# Qwen3Next all layers are sparse and have no nextn now # Qwen3Next all layers are sparse and have no nextn now
self.is_layer_sparse = True self.is_layer_sparse = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment