forward.py 1.96 KB
Newer Older
change3n8's avatar
init  
change3n8 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import deformable_aggregation_ext.deformable_aggregation_ext as da

# Problem sizes used to build the random inputs for the comparison run.
bs = 6
cam = 6
num_feat = 89760
c = 256
scale = 4
group = 8
anchor = 1220
pts = 13

# Every (camera, scale) feature map is given the same H x W here:
# H * W = 3740 and 3740 * 6 cameras * 4 scales = 89760 = num_feat.
H, W = 44, 85
feat_area_per_map = H * W  # 3740

# [cam, scale, 2]: the (H, W) pair repeated for every camera and scale.
spatial_shape = torch.tensor([H, W], dtype=torch.int32).repeat(cam, scale, 1)

# [cam, scale]: entry (i, s) equals (i * scale + s) * feat_area_per_map, i.e.
# consecutive multiples of the per-map area in camera-major order.
# The first row is [0, 3740, 7480, 11220].
scale_start_index = (
    torch.arange(cam * scale, dtype=torch.int32) * feat_area_per_map
).reshape(cam, scale)

# Sanity check: the last start offset plus one map's area must equal num_feat.
last_offset = scale_start_index[cam - 1, scale - 1]
assert last_offset + feat_area_per_map == num_feat, "Total feature size mismatch!"

# Random inputs, created on the host and then moved to the GPU.
# The creation order (feature_maps, sampling_location, weights) is kept so the
# random stream is consumed in the same sequence as before.
feature_maps = torch.rand(bs, num_feat, c).float().cuda()
sampling_location = torch.rand(bs, anchor, pts, cam, 2).float().cuda()
weights = torch.rand(bs, anchor, pts, cam, scale, group).float().cuda()

# Run the operator under test; the index tensors are copied to the GPU per call.
out = da.deformable_aggregation_forward(
    feature_maps,
    spatial_shape.cuda(),
    scale_start_index.cuda(),
    sampling_location,
    weights,
)

# Run the `_ref` variant on the same inputs, with its own device copies of the
# index tensors, so the two outputs can be compared.
new_out = da.deformable_aggregation_forward_ref(
    feature_maps,
    spatial_shape.cuda(),
    scale_start_index.cuda(),
    sampling_location,
    weights,
)

# -------- Compute the overall error --------

def compare_tensors(t1, t2, name="", rtol=1e-4, atol=1e-6):
    """Print and return how closely two tensors agree.

    Args:
        t1: First tensor.
        t2: Second tensor; it is subtracted from ``t1`` element-wise.
        name: Label placed at the start of the printed summary line.
        rtol: Relative tolerance forwarded to ``torch.allclose``.
        atol: Absolute tolerance forwarded to ``torch.allclose``.

    Returns:
        A ``(max_abs_diff, is_close)`` tuple: the largest absolute
        element-wise difference as a Python number (``0.0`` when there are
        no elements to compare) and the ``torch.allclose`` verdict.
    """
    abs_diff = (t1 - t2).abs()
    # max() raises on a zero-element tensor, so report "no difference"
    # instead of crashing when there is nothing to compare.
    max_abs_diff = abs_diff.max().item() if abs_diff.numel() > 0 else 0.0
    is_close = torch.allclose(t1, t2, rtol=rtol, atol=atol)
    print(f"{name}: max_abs_diff={max_abs_diff}, allclose={is_close}")
    return max_abs_diff, is_close

# Compare the output of the operator under test with the `_ref` output and
# print the summary line.
compare_tensors(t1=out, t2=new_out, name="OUTPUT")

# Marker that the script ran to completion.
print("finish")