"lib/llm/src/kv_router/scheduler.rs" did not exist on "b90535aae4190ca4e49cc8249be29e62957d3b2e"
test_conv_gradfix.py 1.42 KB
Newer Older
limm's avatar
limm committed
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial

import pytest
import torch
import torch.nn as nn
from torch.autograd import gradgradcheck

from mmgen.ops import conv2d, conv_transpose2d
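# ``conv2d`` and ``conv_transpose2d`` are mmgen's "gradfix" wrappers around
# the corresponding ``torch.nn.functional`` ops. Following the StyleGAN2
# gradfix design they appear to mirror, they keep the functional signature
# (input, weight, bias, stride, padding, ...) but route the backward pass
# through custom autograd functions so that second-order gradients, as
# exercised by ``gradgradcheck`` below, are supported efficiently.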


class TestConv2d:

    @classmethod
    def setup_class(cls):
        # ``gradgradcheck`` differentiates w.r.t. the input, so it must
        # require gradients.
        cls.input = torch.randn((1, 3, 32, 32), requires_grad=True)
        cls.weight = nn.Parameter(torch.randn(1, 3, 3, 3))

    @pytest.mark.skipif(
        not torch.cuda.is_available()
        or not hasattr(torch.backends.cudnn, 'allow_tf32'),
        reason='requires CUDA and cuDNN TF32 support (PyTorch >= 1.7)')
    def test_conv2d_cuda(self):
        x = self.input.cuda()
        weight = self.weight.cuda()
        res = conv2d(x, weight, None, 1, 1)
        # A 3x3 kernel with stride 1 and padding 1 keeps the spatial size.
        assert res.shape == (1, 1, 32, 32)
        # Verify second-order gradients w.r.t. the input; gradcheck is most
        # reliable in float64, so cast to double if this proves flaky.
        gradgradcheck(partial(conv2d, weight=weight, padding=1, stride=1), x)


class TestConv2dTransposed:

    @classmethod
    def setup_class(cls):
        # As above, ``gradgradcheck`` needs an input that requires grad.
        cls.input = torch.randn((1, 3, 32, 32), requires_grad=True)
        cls.weight = nn.Parameter(torch.randn(3, 1, 3, 3))

    @pytest.mark.skipif(
        not torch.cuda.is_available()
        or not hasattr(torch.backends.cudnn, 'allow_tf32'),
        reason='requires CUDA and cuDNN TF32 support (PyTorch >= 1.7)')
    def test_conv2d_transposed_cuda(self):
        x = self.input.cuda()
        weight = self.weight.cuda()
        res = conv_transpose2d(x, weight, None, 1, 1)
        # With a 3x3 kernel, stride 1 and padding 1, the transposed conv also
        # keeps the spatial size; the weight layout is (in_ch, out_ch, h, w).
        assert res.shape == (1, 1, 32, 32)
        # Verify second-order gradients w.r.t. the input.
        gradgradcheck(
            partial(conv_transpose2d, weight=weight, padding=1, stride=1), x)
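

# Editorial sketch, not part of the original suite: off the CUDA path the
# gradfix wrappers are assumed to fall back to (or at least numerically
# match) the plain ``torch.nn.functional`` ops, as in the StyleGAN2 gradfix
# design. The test name and shapes below are illustrative; if mmgen's CPU
# behaviour differs, drop or loosen this check.
def test_conv2d_cpu_matches_functional():
    x = torch.randn(1, 3, 8, 8)
    weight = torch.randn(4, 3, 3, 3)
    res = conv2d(x, weight, None, 1, 1)
    ref = nn.functional.conv2d(x, weight, None, stride=1, padding=1)
    # A 3x3 kernel with stride 1 and padding 1 keeps the spatial size.
    assert res.shape == ref.shape == (1, 4, 8, 8)
    assert torch.allclose(res, ref)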