linear.py 1.45 KB
Newer Older
weishb's avatar
weishb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import VarLenTensor

# Names exported by `from <this module> import *`.
__all__ = ['SparseLinear', 'ROCM_SAFE_CHUNK', 'rocm_safe_linear']

# Workaround for a ROCm bug on GFX1201 (RX 9070 XT): hipBLASLt and rocBLAS
# GEMM kernels corrupt memory (→ NaN) once N exceeds roughly 800k rows for
# [N, K] @ [K, M] products with small K/M.  Work is split so that no single
# dispatch exceeds this row count, the confirmed-safe threshold
# (2**19 == 524_288).
ROCM_SAFE_CHUNK = 1 << 19


def rocm_safe_linear(feats: torch.Tensor, weight: torch.Tensor, bias=None,
                     chunk_size=None) -> torch.Tensor:
    """Compute ``F.linear(feats, weight, bias)``, dispatching at most ``chunk_size`` rows at a time.

    Workaround for the ROCm large-N GEMM corruption documented next to
    ``ROCM_SAFE_CHUNK``: inputs whose leading dimension exceeds ``chunk_size``
    are split along dim 0 and each slice is fed to ``F.linear`` separately.

    Args:
        feats: Input tensor of shape ``(N, *, in_features)``; chunked along
            dim 0.  A 1-D input has no row dimension and is passed through
            to ``F.linear`` unchunked.
        weight: Weight matrix of shape ``(out_features, in_features)``.
        bias: Optional bias of shape ``(out_features,)``.
        chunk_size: Maximum number of rows per ``F.linear`` call.  Defaults
            to ``ROCM_SAFE_CHUNK``.

    Returns:
        The same result ``F.linear(feats, weight, bias)`` would return, with
        shape ``(N, *, out_features)``.  The output buffer is allocated from
        the first chunk's result, so its dtype/device match what ``F.linear``
        produces (e.g. a reduced dtype under autocast), not ``feats.dtype``.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size is None:
        chunk_size = ROCM_SAFE_CHUNK
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # Slicing dim 0 of a 1-D input would cut the feature axis, so only
    # inputs with a real row dimension larger than the chunk get split.
    if feats.dim() < 2 or feats.shape[0] <= chunk_size:
        return F.linear(feats, weight, bias)

    n_rows = feats.shape[0]
    first = F.linear(feats[:chunk_size], weight, bias)
    # Derive shape/dtype/device from a real F.linear result instead of
    # guessing: handles extra middle dims and autocast output dtypes.
    out = first.new_empty((n_rows, *first.shape[1:]))
    out[:chunk_size] = first
    for start in range(chunk_size, n_rows, chunk_size):
        stop = min(start + chunk_size, n_rows)
        out[start:stop] = F.linear(feats[start:stop], weight, bias)
    return out


class SparseLinear(nn.Linear):
    """``nn.Linear`` that also accepts feature-carrying wrapper objects.

    ``forward`` takes either a plain tensor or an object exposing a ``.feats``
    tensor (presumably a ``VarLenTensor`` or subclass — confirm against
    callers).  The affine map is computed with ``rocm_safe_linear``, so very
    large row counts are split into chunks.
    """

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        # device/dtype are forwarded to nn.Linear so the parameters can be
        # created directly on a target device / in a target dtype; the
        # defaults keep the original three-argument signature working.
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)

    def forward(self, input):
        """Apply the layer to ``input``.

        Args:
            input: A tensor, or an object with a ``.feats`` tensor attribute.

        Returns:
            ``input.replace(out)`` when ``input`` has a ``replace`` method,
            otherwise the raw output tensor ``out``.
        """
        feats = input.feats if hasattr(input, 'feats') else input
        out = rocm_safe_linear(feats, self.weight, self.bias)
        # Wrapper inputs get the new features put back via replace();
        # plain tensors are returned unchanged.
        if hasattr(input, 'replace'):
            return input.replace(out)
        return out