import warnings

import torch
import torch.nn as nn

from liger_kernel.transformers.functional import liger_mhc_coeffs
from liger_kernel.transformers.functional import liger_mhc_post_res
from liger_kernel.transformers.functional import liger_mhc_pre


class LigerMHC(nn.Module):
    """
    Manifold-Constrained Hyper-Connections (mHC) wrapper.

    Wraps an arbitrary layer ``F: [..., C] -> [..., C]`` with multiple residual
    streams, following the mHC architecture (arXiv:2512.24880). The input is a
    multi-stream tensor of shape ``[..., HC, C]`` where ``HC`` is the number of
    residual streams.

    The forward pass performs the following steps (steps 2 and 4 are sketched
    in plain PyTorch after the list):

    1. **Coefficients** -- Compute data-dependent routing coefficients
       (``h_pre``, ``h_post``, ``h_res``) via a fused matmul + RMS
       normalization + Sinkhorn-Knopp iterations.
    2. **Pre-aggregate** -- ``x_in = sum_i h_pre[i] * x[i]``
       (shape: ``[..., C]``)
    3. **Layer** -- ``f_out = layer(x_in)``  (shape: ``[..., C]``)
    4. **Post + residual** --
       ``x_out[o] = sum_i h_res[o,i] * x[i] + h_post[o] * f_out``
       (shape: ``[..., HC, C]``)
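
    A naive PyTorch reference for steps 2 and 4. The coefficient shapes
    (``[..., HC]`` for ``h_pre``/``h_post`` and ``[..., HC, HC]`` for
    ``h_res``) are inferred from the formulas above, not from the kernel
    signatures::

        import torch

        HC, C = 4, 256
        x = torch.randn(2, 128, HC, C)
        h_pre = torch.rand(2, 128, HC)
        h_post = torch.rand(2, 128, HC)
        h_res = torch.rand(2, 128, HC, HC)

        # step 2: aggregate the HC streams into a single token per position
        x_in = (h_pre.unsqueeze(-1) * x).sum(dim=-2)                # [..., C]
        f_out = x_in                                                # stand-in for layer(x_in)
        # step 4: mix the input streams and add the scaled layer output
        x_out = torch.einsum("...oi,...ic->...oc", h_res, x)
        x_out = x_out + h_post.unsqueeze(-1) * f_out.unsqueeze(-2)  # [..., HC, C]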

    Args:
        layer: The module applied to the aggregated single-stream input.
            Must accept ``[..., C]`` and return ``[..., C]``. Common choices
            include ``nn.Linear``, attention layers, or MLP blocks.
        hc: Number of residual streams (called *n* in the original paper).
            Recommended range: [2, 16]. Larger values increase register
            pressure and Triton compile time.
        c: Per-stream channel dimension.
        tmax: Maximum Sinkhorn-Knopp iterations for doubly stochastic
            normalization of ``h_res``. Default: 20.
        rms_eps: Epsilon for RMS normalization of the projection.
            Default: 1e-6.
        pre_eps: Additive epsilon for ``h_pre`` after sigmoid. Default: 0.0.
        sinkhorn_eps: Epsilon added during Sinkhorn normalization.
            Default: 1e-6.
        post_mult: Scaling factor for ``h_post`` after sigmoid. Default: 2.0.
        phi_dtype: Dtype for the projection matrix ``phi``. Using float16 or
            bfloat16 enables Tensor Core acceleration. Default: torch.float16.
        allow_fp32: If True, accept FP32 input tensors. Note that FP32 mode
            does **not** use Tensor Cores and will be slower. Default: False.

    Learnable Parameters:
        - **phi** ``[HC*C, HC*HC + 2*HC]`` -- Projection matrix for computing
          routing coefficients from flattened stream tokens.
        - **b** ``[HC*HC + 2*HC]`` -- Bias for routing logits (float32).
        - **alpha_pre** (scalar) -- Scales pre-routing logits before sigmoid.
        - **alpha_post** (scalar) -- Scales post-routing logits before sigmoid.
        - **alpha_res** (scalar) -- Scales residual logits before the
          Sinkhorn-Knopp normalization sketched below.
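
    For reference, a textbook Sinkhorn-Knopp loop of the kind that ``tmax``
    and ``sinkhorn_eps`` control (illustrative only; the fused kernel's exact
    variant may differ)::

        import torch

        HC, tmax, eps = 4, 20, 1e-6
        logits = torch.randn(HC, HC)
        P = torch.exp(logits)                            # strictly positive matrix
        for _ in range(tmax):
            P = P / (P.sum(dim=-1, keepdim=True) + eps)  # normalize rows
            P = P / (P.sum(dim=-2, keepdim=True) + eps)  # normalize columns
        # rows and columns of P now each sum to ~1 (approximately doubly stochastic)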

    Example::

        import torch
        import torch.nn as nn
        from liger_kernel.transformers import LigerMHC

        # Wrap a linear layer with 4 residual streams of dimension 256
        layer = nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16)
        mhc = LigerMHC(layer, hc=4, c=256, phi_dtype=torch.bfloat16).cuda()

        # Input: [batch, seq_len, num_streams, channels]
        x = torch.randn(2, 128, 4, 256, device="cuda", dtype=torch.bfloat16)
        out = mhc(x)  # shape: [2, 128, 4, 256]

        # In a transformer block (pseudocode):
        # x = mhc_attn(x)   # attention wrapped in LigerMHC
        # x = mhc_ffn(x)    # FFN wrapped in LigerMHC
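
        # One possible way to create and collapse the streams around a stack
        # of mHC blocks (illustrative only; the expansion/reduction scheme is
        # an assumption, not something this module prescribes):
        h = torch.randn(2, 128, 256, device="cuda", dtype=torch.bfloat16)  # [B, T, C]
        streams = h.unsqueeze(-2).expand(-1, -1, 4, -1).contiguous()       # [B, T, HC, C]
        streams = mhc(streams)                                             # one or more mHC blocks
        h_out = streams.mean(dim=-2)                                       # reduce back to [B, T, C]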
    """

    def __init__(
        self,
        layer: nn.Module,
        *,
        hc: int,
        c: int,
        tmax: int = 20,
        rms_eps: float = 1e-6,
        pre_eps: float = 0.0,
        sinkhorn_eps: float = 1e-6,
        post_mult: float = 2.0,
        phi_dtype: torch.dtype = torch.float16,
        allow_fp32: bool = False,
    ):
        super().__init__()
        self.layer = layer
        # hc: number of residual streams (n in the paper)
        self.hc = int(hc)
        self.c = int(c)

        if hc > 16:
            warnings.warn(
                f"hc={hc} exceeds recommended range [2, 16]. "
                "Large values may cause register pressure and increased compile time.",
                stacklevel=2,
            )
        self.tmax = int(tmax)
        self.rms_eps = float(rms_eps)
        self.pre_eps = float(pre_eps)
        self.sinkhorn_eps = float(sinkhorn_eps)
        self.post_mult = float(post_mult)
        self.allow_fp32 = bool(allow_fp32)

        m = hc * hc + 2 * hc
        k = hc * c

        try:
            layer_device = next(self.layer.parameters()).device
        except StopIteration:
            layer_device = torch.device("cpu")

        # Note: for best speed, keep phi in BF16/FP16 to enable tensor-core matmul in Triton.
        self.phi = nn.Parameter(torch.randn(k, m, dtype=phi_dtype, device=layer_device) * 0.02)
        self.b = nn.Parameter(torch.zeros(m, dtype=torch.float32, device=layer_device))
        self.alpha_pre = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=layer_device))
        self.alpha_post = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=layer_device))
        self.alpha_res = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=layer_device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [..., HC, C] (BF16/FP16 recommended; FP32 allowed if allow_fp32=True)
        returns: [..., HC, C]
        """
        if x.shape[-2] != self.hc or x.shape[-1] != self.c:
            raise ValueError(f"Expected x.shape[-2:]=[{self.hc}, {self.c}], got {list(x.shape[-2:])}")

        h_pre, h_post, h_res = liger_mhc_coeffs(
            x,
            self.phi,
            self.b,
            self.alpha_pre,
            self.alpha_post,
            self.alpha_res,
            allow_fp32=self.allow_fp32,
            tmax=self.tmax,
            rms_eps=self.rms_eps,
            pre_eps=self.pre_eps,
            sinkhorn_eps=self.sinkhorn_eps,
            post_mult=self.post_mult,
        )
        x_in = liger_mhc_pre(x, h_pre)  # [..., C]
        # Cast the aggregated input to the wrapped layer's parameter dtype (if
        # the layer has parameters) so mixed-precision setups compose cleanly.
        layer_param = next(self.layer.parameters(recurse=True), None)
        if layer_param is not None and x_in.dtype != layer_param.dtype:
            x_in = x_in.to(layer_param.dtype)
        f_out = self.layer(x_in)  # [..., C]
        x_out = liger_mhc_post_res(x, f_out, h_post, h_res)  # [..., HC, C]
        return x_out

    def extra_repr(self) -> str:
        return f"hc={self.hc}, c={self.c}, tmax={self.tmax}"