"""Compare the deformable-aggregation CUDA op against its ``_ref`` counterpart.

Builds random inputs, runs ``deformable_aggregation_forward`` and
``deformable_aggregation_forward_ref`` from ``deformable_aggregation_ext`` on
the same tensors, and prints the max absolute difference. The process exits
with status 1 when the two outputs are not ``torch.allclose``, so the script
can be used as a pass/fail check.
"""

import torch

# Problem size used for the comparison.
BS = 6  # batch size
CAM = 6  # number of cameras
C = 256  # feature channels
SCALE = 4  # feature maps per camera
GROUP = 8  # size of the last weights dimension
ANCHOR = 1220  # anchors
PTS = 13  # sampling points per anchor
H, W = 44, 85  # every (camera, scale) map uses this size in this test
# All CAM * SCALE maps are stored back to back along dim 1 of feature_maps:
# 6 * 4 * (44 * 85) = 6 * 4 * 3740 = 89760.
NUM_FEAT = 89760


def build_spatial_shape(cam, scale, height, width):
    """Return an int32 tensor of shape [cam, scale, 2] filled with (height, width)."""
    return torch.tensor([[[height, width]] * scale] * cam, dtype=torch.int32)


def build_scale_start_index(cam, scale, feat_area_per_map):
    """Return the int32 start offset of every (camera, scale) map, shape [cam, scale].

    Maps are laid out back to back in camera-major order, so map
    ``i * scale + s`` starts at ``(i * scale + s) * feat_area_per_map``.

    Raises:
        OverflowError: if the total feature length does not fit in int32.
    """
    total = cam * scale * feat_area_per_map
    if total > torch.iinfo(torch.int32).max:
        raise OverflowError(f"total feature length {total} does not fit in int32")
    flat_index = torch.arange(cam * scale, dtype=torch.int32)
    return (flat_index * feat_area_per_map).reshape(cam, scale)


def compare_tensors(t1, t2, name="", rtol=1e-4, atol=1e-6):
    """Print and return ``(max_abs_diff, is_close)`` for two tensors.

    A shape mismatch is reported as ``(inf, False)`` instead of being
    silently broadcast. Empty tensors of equal shape compare as ``(0.0, True)``.
    """
    if t1.shape != t2.shape:
        print(f"{name}: shape mismatch {tuple(t1.shape)} vs {tuple(t2.shape)}")
        return float("inf"), False
    if t1.numel() == 0:
        max_abs_diff = 0.0  # .max() raises on an empty tensor
    else:
        max_abs_diff = (t1 - t2).abs().max().item()
    is_close = torch.allclose(t1, t2, rtol=rtol, atol=atol)
    print(f"{name}: max_abs_diff={max_abs_diff}, allclose={is_close}")
    return max_abs_diff, is_close


def main(seed=0):
    """Run the op and its reference on seeded random CUDA inputs.

    Returns:
        True if the two outputs are ``torch.allclose`` (see ``compare_tensors``).
    """
    if not torch.cuda.is_available():
        raise SystemExit("CUDA is required to run deformable_aggregation_ext.")
    # Imported here so the helpers above stay importable without the compiled extension.
    import deformable_aggregation_ext.deformable_aggregation_ext as da

    torch.manual_seed(seed)  # make a failing run reproducible

    feat_area_per_map = H * W  # 3740
    spatial_shape = build_spatial_shape(CAM, SCALE, H, W)  # [cam, scale, 2]
    scale_start_index = build_scale_start_index(CAM, SCALE, feat_area_per_map)

    # The last map must end exactly at NUM_FEAT, i.e. the maps tile dim 1 of feature_maps.
    last_offset = int(scale_start_index[-1, -1])
    if last_offset + feat_area_per_map != NUM_FEAT:
        raise ValueError("Total feature size mismatch!")

    # Allocate directly on the GPU to avoid a large host buffer plus copy.
    feature_maps = torch.rand(BS, NUM_FEAT, C, device="cuda")  # [bs, num_feat, c]
    # Values are in [0, 1); presumably normalized sampling coordinates — TODO confirm.
    sampling_location = torch.rand(BS, ANCHOR, PTS, CAM, 2, device="cuda")
    weights = torch.rand(BS, ANCHOR, PTS, CAM, SCALE, GROUP, device="cuda")

    args = (
        feature_maps,
        spatial_shape.cuda(),
        scale_start_index.cuda(),
        sampling_location,
        weights,
    )
    out = da.deformable_aggregation_forward(*args)
    # NOTE(review): presumably the reference implementation of the same op — confirm.
    new_out = da.deformable_aggregation_forward_ref(*args)

    _, is_close = compare_tensors(out, new_out, "OUTPUT")
    print("finish")
    return is_close


if __name__ == "__main__":
    raise SystemExit(0 if main() else 1)