components.py 110 Bytes
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
import torch
import torch.nn.functional as F

def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return ``silu(x) * y``, evaluating the SiLU in at least float32.

    The activation is computed in a dtype at least as wide as float32 and
    then cast back to ``x.dtype`` before the multiplication by ``y``.
    Low-precision inputs (e.g. float16/bfloat16) are upcast for the
    activation, while float64 inputs keep their full precision instead of
    being truncated to float32.

    Args:
        x: Input passed through SiLU.
        y: Tensor multiplied elementwise with ``silu(x)``; standard PyTorch
            broadcasting and type promotion apply to the product.

    Returns:
        ``silu(x).to(x.dtype) * y``.
    """
    # promote_types yields float32 for half/bfloat16 (and integer/bool)
    # inputs but keeps float64, so this upcast never loses precision.
    compute_dtype = torch.promote_types(x.dtype, torch.float32)
    return F.silu(x.to(compute_dtype)).to(x.dtype) * y