"sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu" did not exist on "bb418ced802c6dbb6b0ae0d65218327129148769"
gather.py 1.71 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
import torch

rusty1s's avatar
rusty1s committed
3
4
from torch_scatter import segment_cpu, gather_cpu

rusty1s's avatar
rusty1s committed
5
if torch.cuda.is_available():
rusty1s's avatar
rusty1s committed
6
7
    from torch_scatter import gather_cuda, segment_cuda

rusty1s's avatar
linting  
rusty1s committed
8
9
10
11
12
13
14

def gat(is_cuda):
    """Return the gather backend module for the given device flag."""
    if is_cuda:
        return gather_cuda
    return gather_cpu


def seg(is_cuda):
    """Return the segment backend module for the given device flag."""
    if is_cuda:
        return segment_cuda
    return segment_cpu
rusty1s's avatar
rusty1s committed
15

rusty1s's avatar
rusty1s committed
16
17
18
19
20
21
22
23
24

class GatherCOO(torch.autograd.Function):
    """Autograd wrapper for the COO-indexed gather kernel.

    Forward dispatches to the CPU or CUDA ``gather_coo`` implementation;
    backward scatter-sums the incoming gradient back to the shape of ``src``.
    """

    @staticmethod
    def forward(ctx, src, index, out):
        # A caller-provided output buffer is written in place, so autograd
        # must be informed that it was modified.
        if out is not None:
            ctx.mark_dirty(out)
        ctx.src_size = list(src.size())
        ctx.save_for_backward(index)

        impl = gat(src.is_cuda)
        return impl.gather_coo(src, index, out)

    @staticmethod
    def backward(ctx, grad_out):
        index, = ctx.saved_tensors
        src_size = ctx.src_size

        grad_src = None
        if ctx.needs_input_grad[0]:
            # The adjoint of a gather is a segment-sum. The buffer is
            # zero-initialized — presumably because segment_coo only writes
            # the positions present in ``index`` (unlike the CSR variant,
            # which uses an uninitialized buffer).
            acc = grad_out.new_zeros(src_size)
            grad_src, _ = seg(grad_out.is_cuda).segment_coo(
                grad_out, index, acc, 'sum')

        # No gradients flow to ``index`` or ``out``.
        return grad_src, None, None


class GatherCSR(torch.autograd.Function):
    """Autograd wrapper for the CSR-indexed gather kernel.

    Forward dispatches to the CPU or CUDA ``gather_csr`` implementation;
    backward segment-sums the incoming gradient back to the shape of ``src``.
    """

    @staticmethod
    def forward(ctx, src, indptr, out):
        # A caller-provided output buffer is written in place, so autograd
        # must be informed that it was modified.
        if out is not None:
            ctx.mark_dirty(out)
        ctx.src_size = list(src.size())
        ctx.save_for_backward(indptr)

        impl = gat(src.is_cuda)
        return impl.gather_csr(src, indptr, out)

    @staticmethod
    def backward(ctx, grad_out):
        indptr, = ctx.saved_tensors
        src_size = ctx.src_size

        grad_src = None
        if ctx.needs_input_grad[0]:
            # The adjoint of a gather is a segment-sum. An uninitialized
            # buffer is used here — presumably safe because the CSR pointer
            # spans cover every output row; verify against segment_csr.
            acc = grad_out.new_empty(src_size)
            grad_src, _ = seg(grad_out.is_cuda).segment_csr(
                grad_out, indptr, acc, 'sum')

        # No gradients flow to ``indptr`` or ``out``.
        return grad_src, None, None
rusty1s's avatar
rusty1s committed
59
60
61


def gather_coo(src, index, out=None):
    """Gather ``src`` values by a COO ``index`` tensor, with autograd support.

    Thin public wrapper around :class:`GatherCOO`. If ``out`` is given, the
    result is written into it in place.
    """
    return GatherCOO.apply(src, index, out)
rusty1s's avatar
rusty1s committed
63
64
65


def gather_csr(src, indptr, out=None):
    """Gather ``src`` values by a CSR ``indptr`` tensor, with autograd support.

    Thin public wrapper around :class:`GatherCSR`. If ``out`` is given, the
    result is written into it in place.
    """
    return GatherCSR.apply(src, indptr, out)