"""custom_op.py: base class for modules whose ``forward`` dispatches to a
platform-specific implementation (CUDA, HIP, or a native fallback)."""
from typing import Optional

import torch
from torch import nn

from sglang.srt.utils import is_cuda, is_hip

# Detect the platform once at import time; CustomOp.dispatch_forward reads
# these flags to choose which forward_* implementation to bind.
_is_cuda = is_cuda()
_is_hip = is_hip()


class CustomOp(nn.Module):
    """Base module whose ``forward`` dispatches to a platform-specific method.

    The implementation is chosen once, in ``__init__``, by
    :meth:`dispatch_forward` using the module-level ``_is_cuda`` / ``_is_hip``
    flags. Subclasses override ``forward_native`` and/or ``forward_cuda``;
    the other ``forward_*`` hooks fall back to one of those two.
    """

    def __init__(self):
        super().__init__()
        # Resolve the backend once so each forward call skips the platform checks.
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        """Run the implementation selected at construction time."""
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """Fallback implementation; subclasses must override it.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement forward_native"
        )

    def forward_cuda(self, *args, **kwargs):
        """CUDA implementation; subclasses must override it.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement forward_cuda"
        )

    def forward_hip(self, *args, **kwargs):
        """HIP implementation; by default reuses the CUDA path."""
        return self.forward_cuda(*args, **kwargs)

    # NOTE: dispatch_forward never selects the three hooks below; they are
    # only reached when a caller invokes them directly.
    def forward_xpu(self, *args, **kwargs):
        """XPU implementation; by default reuses the native path."""
        return self.forward_native(*args, **kwargs)

    def forward_hpu(self, *args, **kwargs):
        """HPU implementation; by default reuses the native path."""
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        """CPU implementation; by default reuses the native path."""
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self):
        """Return the bound ``forward_*`` method for the detected platform.

        CUDA takes precedence over HIP. Every other platform uses
        ``forward_native``.
        """
        if _is_cuda:
            return self.forward_cuda
        if _is_hip:
            return self.forward_hip
        return self.forward_native