test_encode_reg_target.py 3.63 KB
Newer Older
change3n8's avatar
init  
change3n8 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
import encode_reg_target_ext
import time

# -----------------------------
# reference implementation
# -----------------------------
def encode_reg_target_ref(box_target, device=None, *, xyz=(0, 1, 2), wlh=(3, 4, 5), yaw=6):
    """Encode each box tensor into the regression-target layout (pure PyTorch).

    For every tensor ``box`` in ``box_target`` the result is the concatenation,
    along the last dimension, of:

      * the columns ``box[..., xyz]`` unchanged,
      * the natural log of the columns ``box[..., wlh]``,
      * ``sin`` and then ``cos`` of the column ``box[..., yaw]``,
      * every column after ``yaw`` unchanged.

    So an input with ``C`` columns yields an output with ``C + 1`` columns.

    Args:
        box_target: iterable of tensors to encode.
        device: optional device; when given, each encoded tensor is moved there
            after it has been computed.
        xyz: column indices copied through unchanged (default ``(0, 1, 2)``).
        wlh: column indices that are log-transformed (default ``(3, 4, 5)``).
        yaw: index of the angle column (default ``6``); all columns after it
            are appended unchanged.

    Returns:
        A list with one encoded tensor per input, in input order.
    """
    # Keyword defaults replace the former dependency on X/Y/Z/W/L/H/YAW, which
    # are only assigned when the file runs as a script (NameError on import).
    # Lists (not tuples) keep this as advanced indexing, exactly as before.
    xyz_cols = list(xyz)
    wlh_cols = list(wlh)

    outputs = []
    for box in box_target:
        yaw_col = box[..., yaw]
        encoded = torch.cat(
            [
                box[..., xyz_cols],
                box[..., wlh_cols].log(),
                torch.sin(yaw_col).unsqueeze(-1),
                torch.cos(yaw_col).unsqueeze(-1),
                box[..., yaw + 1:],
            ],
            dim=-1,
        )

        if device is not None:
            encoded = encoded.to(device=device)

        outputs.append(encoded)

    return outputs


# -----------------------------
# Optimized HIP/C++ implementation
# -----------------------------
def encode_reg_target_optimized(box_target_list, device=None):
    """Encode each box tensor using the compiled ``encode_reg_target_ext`` kernel.

    All inputs are first placed on one device. That is ``device`` when given,
    otherwise the device of the first tensor. Each tensor is then passed to
    ``encode_reg_target_ext.encode_reg_t`` together with its row count.

    Args:
        box_target_list: sequence of 2-D tensors. Each shape is unpacked as
            ``(rows, cols)``, so other ranks raise ``ValueError``.
        device: optional target device for the inputs.

    Returns:
        A list of the extension's outputs, one per input, in input order.
        An empty input gives an empty list without touching the extension.
    """
    if len(box_target_list) == 0:
        return []

    target_device = box_target_list[0].device if device is None else device

    # Move everything first, then launch; tensors already on the target
    # device are reused as-is.
    placed = [
        box if box.device == target_device else box.to(target_device)
        for box in box_target_list
    ]

    encoded = []
    for box in placed:
        num_rows, _num_cols = box.shape
        encoded.append(encode_reg_target_ext.encode_reg_t(box, num_rows))
    return encoded


# -----------------------------
# Main benchmark
# -----------------------------
# Column indices of the box tensors. The encoded output has 11 columns
# (X..H, SIN_YAW, COS_YAW, VX, VY, VZ); the raw input keeps the angle as a
# single column at YAW, followed by the remaining columns. Defined at module
# level (not under the __main__ guard) so they also exist when this file is
# imported instead of run as a script.
X, Y, Z, W, L, H, SIN_YAW, COS_YAW, VX, VY, VZ = range(11)
YAW = 6


def time_per_call_ms(fn, args, num_runs):
    """Time ``fn(*args)`` over ``num_runs`` calls with CUDA events.

    The measured span runs from the first to the last recorded event. It
    therefore includes host-side launch overhead as well as device work.

    Args:
        fn: callable to benchmark.
        args: tuple of positional arguments passed to ``fn``.
        num_runs: number of timed calls; must be >= 1.

    Returns:
        ``(average_ms_per_call, output_of_last_call)``.

    Raises:
        ValueError: if ``num_runs`` < 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    out = None
    start_event.record()
    for _ in range(num_runs):
        out = fn(*args)
    end_event.record()

    # elapsed_time needs both events to have completed.
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_runs, out


def compare_outputs(out_ref, out_opt, atol=1e-5):
    """Compare two lists of tensors pairwise.

    Args:
        out_ref: reference tensors.
        out_opt: tensors under test, in the same order.
        atol: absolute tolerance forwarded to ``torch.allclose``.

    Returns:
        ``(all_close, max_abs_diff, mean_abs_diff)``, where ``mean_abs_diff``
        is the unweighted average of each non-empty pair's mean absolute
        difference. If the list lengths differ or any pair's shapes differ,
        returns ``(False, inf, inf)`` instead of silently truncating or
        broadcasting.
    """
    inf = float("inf")
    if len(out_ref) != len(out_opt):
        return False, inf, inf

    all_close = True
    max_diff = 0.0
    diff_sum = 0.0
    compared = 0

    for ref, opt in zip(out_ref, out_opt):
        if ref.shape != opt.shape:
            return False, inf, inf
        if ref.numel() == 0:
            # Nothing to compare; .max() would raise on an empty tensor.
            continue
        diff = torch.abs(ref - opt)
        if not torch.allclose(ref, opt, atol=atol):
            all_close = False
        max_diff = max(max_diff, diff.max().item())
        diff_sum += diff.mean().item()
        compared += 1

    mean_diff = diff_sum / compared if compared else 0.0
    return all_close, max_diff, mean_diff


if __name__ == "__main__":
    if not torch.cuda.is_available():
        raise SystemExit("This benchmark requires a CUDA/HIP device.")

    device = "cuda:0"

    D_dim = 10

    # Row counts of the individual box tensors in one call.
    m_dims = [
        40, 12, 25, 111, 45, 84, 19, 14, 52, 13,
        20, 19, 28, 28, 6, 62, 26, 56, 13, 33
    ]

    # torch.rand samples [0, 1); clamp away exact zeros so log() on the
    # W/L/H columns can never yield -inf (which would make the diff NaN).
    box_target = [
        torch.rand(m, D_dim, dtype=torch.float32, device=device).clamp_min_(1e-6)
        for m in m_dims
    ]

    num_warmup = 10
    num_runs = 100

    # -----------------------------
    # Warmup
    # -----------------------------
    for _ in range(num_warmup):
        encode_reg_target_ref(box_target, device)
    for _ in range(num_warmup):
        encode_reg_target_optimized(box_target, device)

    torch.cuda.synchronize()

    # -----------------------------
    # Measure reference
    # -----------------------------
    avg_time_ms_ref, out_ref = time_per_call_ms(
        encode_reg_target_ref, (box_target, device), num_runs
    )
    print(f"reference avg time: {avg_time_ms_ref:.4f} ms")

    # -----------------------------
    # Measure optimized
    # -----------------------------
    avg_time_ms_opt, out_optimized = time_per_call_ms(
        encode_reg_target_optimized, (box_target, device), num_runs
    )
    print(f"Optimized kernel avg time: {avg_time_ms_opt:.4f} ms")

    # -----------------------------
    # Accuracy check
    # -----------------------------
    all_close, max_diff, mean_diff = compare_outputs(out_ref, out_optimized, atol=1e-5)

    print("Accuracy check:")
    print("Pass" if all_close else "Fail")
    print(f"Max diff:  {max_diff}")
    print(f"Mean diff: {mean_diff}")
    print("Test end.")