import torch
import sys
import apex_C
import numpy as np
from compare import compare
from norm import pt_norm, get_norm_shape
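# compare and norm are local test helpers (assumed to live alongside this script):
# compare checks results against the references below; pt_norm/get_norm_shape give
# the PyTorch reference norm and the shape of the per-slice norms tensor.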


torch.manual_seed(2)
torch.cuda.manual_seed(2)
# torch.cuda.manual_seed_all(2)
torch.set_printoptions(precision=10)

sizes = [
    # (3,  512, 1024),
    # (3,  512, 1536),
    # (3,  768, 1536),
    # (3,  768, 2048),
    # (3, 1024, 2048),
    # (1, 1024, 4096),
    # (1, 2048, 8192),
    # (1, 4096, 4096), # this is not one of the fairseq sizes, just a reference benchmark.
    (4096, 4096, 1), # this is not one of the fairseq sizes, just a reference benchmark.
    # (353, 55, 353), # this is not one of the fairseq sizes, just a reference benchmark.
    ]

# rows = 3
# cols = 512
# fast = 1024
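# HALF: feed the kernel fp16 inputs (an fp32 copy is kept as a control).
# RAND: use random inputs instead of constant fills.
# dim:  passed to get_norm_shape and the kernel; presumably the dimension along
#       which a separate norm is kept per slice.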
HALF = True
RAND = True
dim = 0
    

for rows, cols, fast in sizes:
    dims = rows, cols, fast
     
    print("\n\nTESTING dims = {}\n\n".format(dims))

    if RAND:
        pt_in = 1.*torch.cuda.FloatTensor(*dims).uniform_()
        g = torch.cuda.FloatTensor(*get_norm_shape(pt_in, dim)).uniform_()
    else:
        pt_in = torch.cuda.FloatTensor(*dims).fill_(1.)
        g = torch.cuda.FloatTensor(*get_norm_shape(pt_in, dim)).fill_(6.0)
    
    # per_col = torch.arange(1,cols+1).cuda()
    # print((rows*per_col*per_col).sqrt())
    # pt_in *= per_col
    
    cuda_out   =   torch.cuda.FloatTensor(*dims).fill_(0.)
    cuda_norms =   torch.cuda.FloatTensor(*get_norm_shape(pt_in, dim)).fill_(0.)
    
    # Save a copy of the input as float
    pt_in_fp32 = pt_in.clone()
    g_fp32     = g.clone()
    
    if HALF:
        pt_in    =    pt_in.half()
        g        =        g.half()
        cuda_out = cuda_out.half()
    
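    # Fused CUDA forward: expected to fill the preallocated cuda_out and cuda_norms
    # in place, matching the reference computation below (out = in * (g / norm)).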
    apex_C.weight_norm_fwd(cuda_out, cuda_norms, pt_in, g, dim)
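    # Block until all queued CUDA work (including the forward kernel) has completed.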
    torch.cuda.synchronize()
    # quit()

    print("type(cuda_out) = {}\n".format(type(cuda_out)))
    
    rownorms      = pt_norm(pt_in, dim)
    rownorms_fp32 = pt_norm(pt_in_fp32, dim)
    
    print("rownorms_fp32:")
    print(rownorms_fp32)
    print("cuda_norms"    )
    print(cuda_norms   )
    
    # rownorms is broadcast; torch.div(pt_in, rownorms) and pt_in/rownorms work the same way
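    # pt_out is the PyTorch reference at the kernel's input precision;
    # pt_out_control is the fp32 reference (presumably used by compare to bound the error).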
    pt_out         = pt_in*(g/rownorms)
    pt_out_control = pt_in_fp32*(g_fp32/rownorms_fp32)
    
    compare(cuda_out, pt_out, pt_out_control, rows)