# utils.py (971 bytes)
# NOTE: "Newer Older" here was navigation text from the page this file was copied from.
1
2
import torch

3

4
5
def check_args(args):
    """Validate combinations of command-line options.

    Checks are raised explicitly instead of using ``assert`` statements, so
    they still run when Python is started with ``-O``. The exception type
    (``AssertionError``) and messages are unchanged for existing callers.

    Args:
        args: parsed options object (presumably an ``argparse.Namespace``)
            exposing ``only_1st``, ``only_2nd``, ``dim``, ``async_update``
            and ``mix``.

    Raises:
        AssertionError: if both ``--only_1st`` and ``--only_2nd`` are set;
            if neither is set and ``dim`` is odd; or if ``--async_update``
            is set without ``--mix``.
    """
    # Number of "only" selections made (the flags are summed as 0/1 values).
    flag = sum([args.only_1st, args.only_2nd])
    if flag > 1:
        raise AssertionError(
            "no more than one selection from --only_1st and --only_2nd"
        )
    # The even-dimension requirement applies only when neither "only" flag is
    # set; presumably the dimension is then split in half between two parts
    # (TODO confirm against the model code).
    if flag == 0 and args.dim % 2 != 0:
        raise AssertionError("embedding dimension must be an even number")
    if args.async_update and not args.mix:
        raise AssertionError("please use --async_update with --mix")

14

15
def sum_up_params(model):
    """Count the model parameters, print the total and return it.

    Element counts are read with ``numel()`` directly on each tensor. The
    original copied every tensor to the CPU first, but the element count
    does not depend on the device, so that copy only cost time and host
    memory.

    Args:
        model: object exposing boolean-like flags ``fst`` and ``snd``, the
            tensors read below, ``lookup_table``, and optionally
            ``index_emb_negu``.

    Returns:
        int: the total count, which is also printed as ``"#params <total>"``.
    """
    counts = []
    if model.fst:
        counts.append(model.fst_u_embeddings.weight.numel())
        counts.append(model.fst_state_sum_u.numel())
    if model.snd:
        # The second-order tensors are counted twice. The reason is not
        # visible here; presumably a paired tensor of the same size exists
        # (TODO confirm against the model definition).
        counts.append(model.snd_u_embeddings.weight.numel() * 2)
        counts.append(model.snd_state_sum_u.numel() * 2)
    counts.append(model.lookup_table.numel())
    # index_emb_negu is optional. Skip it when the attribute is missing or
    # has no numel() (e.g. it is None). Only AttributeError is caught, so
    # other errors and interrupts still propagate.
    try:
        counts.append(model.index_emb_negu.numel() * 2)
    except AttributeError:
        pass
    total = sum(counts)
    print("#params " + str(total))
    return total