"megatron/legacy/model/transformer.py" did not exist on "2cc3dac719ebc2456a7857ade2f430530dc31ab3"
compat.py 1.36 KB
Newer Older
JR_ZZU's avatar
JR_ZZU committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch

# True for post-0.4, when Variables/Tensors merged.
def variable_is_tensor():
    v = torch.autograd.Variable()
    return isinstance(v, torch.Tensor)

# Companion probe to `variable_is_tensor`: True only if a plain Tensor's
# concrete type is `torch.autograd.Variable`.
def tensor_is_variable():
    x = torch.Tensor()
    return type(x) == torch.autograd.Variable

# False for post-0.4, where the default Tensor's concrete type is
# `torch.Tensor` rather than `torch.FloatTensor`.
def tensor_is_float_tensor():
    x = torch.Tensor()
    return type(x) == torch.FloatTensor
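
# Illustrative sketch (not from the original module): the probes above are
# meant to gate version-specific code paths at import time, e.g.
#
#     if variable_is_tensor():
#         pass  # post-0.4 path: Variables and Tensors are one type
#     else:
#         pass  # pre-0.4 path: wrap/unwrap Variables explicitly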

# Akin to `torch.is_tensor`, but returns True for Variable
# objects in pre-0.4.
def is_tensor_like(x):
    return torch.is_tensor(x) or isinstance(x, torch.autograd.Variable)

# Wraps `torch.is_floating_point` if present, otherwise checks the suffix
# of `x.type()` (e.g. 'torch.FloatTensor', 'torch.cuda.HalfTensor'), which
# covers both CPU and CUDA tensors.
def is_floating_point(x):
    if hasattr(torch, 'is_floating_point'):
        return torch.is_floating_point(x)
    try:
        torch_type = x.type()
        return torch_type.endswith('FloatTensor') or \
            torch_type.endswith('HalfTensor') or \
            torch_type.endswith('DoubleTensor')
    except AttributeError:
        return False

# Returns the Python number held by a zero-dim tensor: `.item()` where it
# exists (post-0.4), indexing into the underlying data otherwise.
def scalar_python_val(x):
    if hasattr(x, 'item'):
        return x.item()
    elif isinstance(x, torch.autograd.Variable):
        return x.data[0]
    else:
        return x[0]

# Accounts for the possibility that some ops may be removed from a namespace.
def filter_attrs(module, attrs):
    return [attrname for attrname in attrs if hasattr(module, attrname)]
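
# A minimal usage sketch, not part of the original module: exercises the
# helpers above on a post-0.4 PyTorch build. Nothing here is assumed beyond
# the functions defined in this file.
if __name__ == "__main__":
    t = torch.zeros(3)
    print("variable_is_tensor:", variable_is_tensor())        # True post-0.4
    print("is_tensor_like:", is_tensor_like(t))               # True
    print("is_floating_point:", is_floating_point(t))         # True (float32)
    print("scalar_python_val:", scalar_python_val(t.sum()))   # 0.0
    # Keep only the names that actually exist in the torch namespace.
    print("filter_attrs:", filter_attrs(torch, ["add", "mul", "no_such_op"]))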