"official/vision/detection/__init__.py" did not exist on "2c5c3f3534c8109fed69b6523dc1fe38950596d3"
utils.py 1.88 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import random
from functools import wraps

import numpy as np
import torch
import torch.utils._device


def setup_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN behaviour."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
    """Skip `torch.nn.init` initialisation kernels and place newly constructed tensors on `device`."""

    def __init__(self, device=None):
        self.device = device

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Turn torch.nn.init.* calls into no-ops: hand back the tensor untouched.
        if getattr(func, '__module__', None) == 'torch.nn.init':
            if 'tensor' in kwargs:
                return kwargs['tensor']
            else:
                return args[0]
        # Route tensor constructors (torch.empty, torch.zeros, ...) to the target device
        # unless the caller already passed an explicit `device`.
        if self.device is not None and func in torch.utils._device._device_constructors() and kwargs.get('device') is None:
            kwargs['device'] = self.device
        return func(*args, **kwargs)

def with_empty_init(func):
    """Decorator: run `func` under EmptyInitOnDevice('cpu') so parameter initialisation is skipped."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        with EmptyInitOnDevice('cpu'):
            return func(*args, **kwargs)
    return wrapper
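
# Minimal usage sketch (the constructor below is hypothetical, not part of this file):
# wrapping a builder with `with_empty_init` creates parameters on CPU and skips the
# usual random initialisation, so checkpoint weights can be loaded right afterwards.
#
#   @with_empty_init
#   def build_model():
#       return MyTransformer(config)          # hypothetical model constructor
#
#   model = build_model()
#   model.load_state_dict(torch.load("ckpt.pt"))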


def culens2mask(
    cu_seqlens=None,
    cu_seqlens_kv=None,
    max_seqlen=None,
    max_seqlen_kv=None,
    is_causal=False
):
    """Expand cumulative sequence lengths (the packed/varlen-attention format) into a
    dense boolean mask of shape [bsz, max_seqlen, max_seqlen_kv], where True marks
    query/key positions that are allowed to attend to each other."""
    assert len(cu_seqlens) == len(cu_seqlens_kv), "q and kv must have the same batch size"
    bsz = len(cu_seqlens) - 1
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
    
    attn_mask = torch.zeros(bsz, max_seqlen, max_seqlen_kv, dtype=torch.bool)
    for i, (seq_len, seq_len_kv) in enumerate(zip(seqlens, seqlens_kv)):
        if is_causal:
            # Lower-triangular block: query position j may attend to kv positions <= j.
            attn_mask[i, :seq_len, :seq_len_kv] = torch.tril(torch.ones(seq_len, seq_len_kv)).bool()
        else:
            # Full block: every valid query position may attend to every valid kv position.
            attn_mask[i, :seq_len, :seq_len_kv] = torch.ones([seq_len, seq_len_kv], dtype=torch.bool)

    return attn_mask
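

# Minimal usage sketch (illustrative values, not from this repo): two packed sequences
# of lengths 3 and 2 with identical q/kv layouts, padded to max_seqlen = 3.
#
#   cu = torch.tensor([0, 3, 5])
#   mask = culens2mask(cu, cu, max_seqlen=3, max_seqlen_kv=3, is_causal=True)
#   # mask.shape == (2, 3, 3); mask[0] is lower-triangular, and mask[1, 2:, :]
#   # stays False because the second sequence only has length 2.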