"tools/vscode:/vscode.git/clone" did not exist on "6d281c311c37bc73b77dfadb7b7131d43e5dd733"
Unverified Commit b6dd4bcb authored by cao1zhg's avatar cao1zhg Committed by GitHub
Browse files

feat: update support for qwen3next model (#10466)

parent b2435be6
......@@ -86,8 +86,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
b_g = tl.load(p_g).to(tl.float32)
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
b_q = b_q * scale
# [BK, BV]
b_h *= exp(b_g)
......@@ -411,8 +411,8 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
b_g = tl.load(p_g).to(tl.float32)
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
b_q = b_q * scale
# [BK, BV]
b_h *= exp(b_g)
......
......@@ -119,8 +119,8 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
# Apply L2 normalization if enabled
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
b_q = b_q * scale
......
......@@ -239,6 +239,7 @@ class Qwen3GatedDeltaNet(nn.Module):
self,
config: Qwen3NextConfig,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
......@@ -278,6 +279,7 @@ class Qwen3GatedDeltaNet(nn.Module):
input_size=self.hidden_size,
output_size=projection_size_qkvz,
bias=False,
quant_config=quant_config,
tp_rank=self.attn_tp_rank,
tp_size=self.attn_tp_size,
)
......@@ -285,6 +287,7 @@ class Qwen3GatedDeltaNet(nn.Module):
input_size=self.hidden_size,
output_size=projection_size_ba,
bias=False,
quant_config=None,
tp_rank=self.attn_tp_rank,
tp_size=self.attn_tp_size,
)
......@@ -336,6 +339,7 @@ class Qwen3GatedDeltaNet(nn.Module):
self.value_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
input_is_parallel=True,
reduce_results=False,
tp_rank=self.attn_tp_rank,
......@@ -493,7 +497,7 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.config = config
self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, alt_stream)
self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, quant_config, alt_stream)
# Qwen3Next all layers are sparse and have no nextn now
self.is_layer_sparse = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment