fix_weight.py 751 Bytes
Newer Older
raojy's avatar
raojy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import os

import torch

# Official LiDAR-only BEVFusion checkpoint to convert.
CKPT_PATH = 'pth/bevfusion_lidar_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d-2628f933.pth'


def fix_sparse_conv_weights(state_dict, prefix='pts_middle_encoder'):
    """Re-lay-out 5-D sparse-conv kernels in ``state_dict`` in place.

    Every tensor whose key contains ``prefix`` and that has exactly 5 dims is
    replaced by ``value.permute(1, 2, 3, 4, 0).contiguous()``: dim 0 moves to
    the end, e.g. (16, 3, 3, 3, 16) -> (3, 3, 3, 16, 16).
    NOTE(review): presumably this converts (out, kD, kH, kW, in) kernels into
    the (kD, kH, kW, in, out) layout the target sparse-conv backend expects;
    confirm against that library. Running it twice on the same weights would
    permute them again, so only feed it unconverted checkpoints.

    Args:
        state_dict: Mapping of parameter name -> value; modified in place.
        prefix: Substring a key must contain to be converted.

    Returns:
        Number of tensors that were permuted.
    """
    count = 0
    # Snapshot the items so the loop is independent of the in-place writes.
    for key, value in list(state_dict.items()):
        if prefix in key and isinstance(value, torch.Tensor) and value.dim() == 5:
            state_dict[key] = value.permute(1, 2, 3, 4, 0).contiguous()
            count += 1
    return count


def fixed_output_path(ckpt_path):
    """Return ``<root>_fixed<ext>`` for ``ckpt_path``.

    Only the final extension is touched, so the result never equals the input
    (the old ``str.replace('.pth', ...)`` returned the input unchanged when it
    had no ``.pth``, overwriting the source checkpoint).
    """
    root, ext = os.path.splitext(ckpt_path)
    return f'{root}_fixed{ext}'


def main(ckpt_path=CKPT_PATH):
    """Load ``ckpt_path``, fix its sparse-conv kernels, save a ``_fixed`` copy.

    Returns:
        Path the converted checkpoint was written to.
    """
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which can reject checkpoints whose metadata pickles arbitrary objects;
    # if loading fails, pass weights_only=False (trusted files only).
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = ckpt['state_dict']
    fixed_count = fix_sparse_conv_weights(state_dict)
    ckpt['state_dict'] = state_dict
    fixed_path = fixed_output_path(ckpt_path)
    torch.save(ckpt, fixed_path)
    print(f'✅ 纯 LiDAR 权重修复完成!已保存至 {fixed_path},共处理 {fixed_count} 个层。')
    return fixed_path


if __name__ == '__main__':
    main()