"""Small spconv performance smoke test.

Builds a random sparse voxel tensor, pushes it through a two-layer
submanifold sparse convolution network, reports timings, optionally times a
dense ``nn.Conv3d`` layer on a grid of the same size, and finally checks that
the number of active points is unchanged by the network.
"""

import argparse
import time
from typing import Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import spconv as spconv_root
import spconv.pytorch as spconv


def _positive_int(text: str) -> int:
    """argparse type: parse *text* as an integer that must be >= 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected an integer >= 1, got {text!r}")
    return value


def _non_negative_int(text: str) -> int:
    """argparse type: parse *text* as an integer that must be >= 0."""
    value = int(text)
    if value < 0:
        raise argparse.ArgumentTypeError(f"expected an integer >= 0, got {text!r}")
    return value


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the smoke test.

    ``--num-runs`` and ``--num-points`` are used as divisors later in the
    script (average time per run, sparsity ratio), so they are rejected here
    when smaller than 1 instead of failing later with ``ZeroDivisionError``.
    """
    parser = argparse.ArgumentParser(
        description="Run a small spconv performance smoke test."
    )
    parser.add_argument("--dtype", choices=("fp32", "fp16"), default="fp32")
    parser.add_argument("--num-runs", type=_positive_int, default=100)
    parser.add_argument("--warmup-runs", type=_non_negative_int, default=10)
    parser.add_argument("--num-points", type=_positive_int, default=5000)
    parser.add_argument("--skip-dense", action="store_true")
    return parser.parse_args()


def torch_dtype(name: str) -> torch.dtype:
    """Map a ``--dtype`` choice to a torch dtype.

    ``"fp16"`` maps to ``torch.float16``; any other value maps to
    ``torch.float32``.
    """
    return torch.float16 if name == "fp16" else torch.float32


def synchronize(device: Union[str, torch.device]) -> None:
    """Wait for pending CUDA work when *device* is a CUDA device.

    The device is normalised through ``torch.device`` so that ``"cuda"``,
    ``"cuda:0"`` and ``torch.device`` objects are all recognised. For any
    non-CUDA device this is a no-op.
    """
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize()


class SimpleSparseConvNet(nn.Module):
    """Two ``SubMConv3d`` layers, each followed by BatchNorm1d and ReLU.

    The first layer maps ``in_channels`` to 16 channels and the second maps
    16 to ``out_channels``. Normalisation and activation are applied to the
    feature matrix of the sparse tensor (``.features``) and written back with
    ``replace_feature``.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv1 = spconv.SubMConv3d(
            in_channels, 16, kernel_size=3, padding=1, bias=False, indice_key="subm1"
        )
        self.conv2 = spconv.SubMConv3d(
            16, out_channels, kernel_size=3, padding=1, bias=False, indice_key="subm2"
        )
        self.bn1 = nn.BatchNorm1d(16)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run both conv/BN/ReLU stages on sparse tensor *x* and return the result."""
        out = self.conv1(x)
        out = out.replace_feature(self.relu(self.bn1(out.features)))
        out = self.conv2(out)
        out = out.replace_feature(self.relu(self.bn2(out.features)))
        return out


def create_sparse_input(
    batch_size: int,
    num_points: int,
    spatial_shape: Sequence[int],
    in_channels: int,
    device: Union[str, torch.device],
    dtype: torch.dtype,
):
    """Create a random ``SparseConvTensor`` with *num_points* distinct voxels.

    Voxels are sampled without replacement from the full
    ``(batch_size, *spatial_shape)`` grid, so every row of the index tensor is
    unique, every axis respects its own extent (non-cubic grids are handled),
    and points are spread over all batch entries.

    Args:
        batch_size: Number of samples in the batch.
        num_points: Total number of active voxels to generate.
        spatial_shape: Extent of the voxel grid along each spatial axis.
        in_channels: Number of feature channels per active voxel.
        device: Device on which the tensors are created.
        dtype: Floating dtype of the feature matrix.

    Returns:
        A ``spconv.SparseConvTensor`` whose indices are int32 rows of
        ``(batch_index, *spatial_coords)`` and whose features are standard
        normal samples of shape ``(num_points, in_channels)``.

    Raises:
        ValueError: If *num_points* exceeds the number of voxels available.
    """
    grid_shape = (batch_size, *spatial_shape)
    total_voxels = int(np.prod(grid_shape))
    if num_points > total_voxels:
        raise ValueError(
            f"num_points={num_points} exceeds the {total_voxels} voxels available "
            f"for batch_size={batch_size}, spatial_shape={tuple(spatial_shape)}"
        )

    # Independent per-axis random draws can hit the same voxel twice, which
    # makes the real number of active voxels smaller than num_points.
    # NOTE(review): spconv is understood to expect unique indices - confirm
    # against the spconv usage documentation.
    # randperm allocates one int64 per grid voxel; fine for smoke-test grids.
    flat_indices = torch.randperm(total_voxels)[:num_points]
    unravelled = np.unravel_index(flat_indices.numpy(), grid_shape)
    coors = torch.from_numpy(np.stack(unravelled, axis=1).astype(np.int32)).to(device)

    features = torch.randn(num_points, in_channels, device=device, dtype=dtype)
    return spconv.SparseConvTensor(
        indices=coors,
        features=features,
        spatial_shape=spatial_shape,
        batch_size=batch_size,
    )


def _average_forward_time(module, inputs, device, warmup_runs: int, num_runs: int) -> float:
    """Return the mean wall-clock seconds of ``module(inputs)`` over *num_runs* calls.

    The module is switched to eval mode and run under ``torch.no_grad()``.
    *warmup_runs* untimed calls are made first, and the device is
    synchronised before the clock starts and before it stops.

    Raises:
        ValueError: If *num_runs* is smaller than 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    module.eval()
    with torch.no_grad():
        for _ in range(warmup_runs):
            module(inputs)
        synchronize(device)

        # perf_counter is monotonic and high resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(num_runs):
            module(inputs)
        synchronize(device)
        elapsed = time.perf_counter() - start
    return elapsed / num_runs


def run_sparse_forward(model, x, device, warmup_runs: int, num_runs: int) -> float:
    """Return the average forward time in seconds of *model* on sparse input *x*.

    Raises:
        ValueError: If *num_runs* is smaller than 1.
    """
    return _average_forward_time(model, x, device, warmup_runs, num_runs)


def run_dense_forward(
    channels: Tuple[int, int],
    spatial_shape: Sequence[int],
    kernel_size,
    device: Union[str, torch.device],
    dtype: torch.dtype,
    warmup_runs: int,
    num_runs: int,
) -> float:
    """Return the average forward time in seconds of a single dense ``nn.Conv3d``.

    The layer maps ``channels[0]`` to ``channels[1]`` channels and is run on a
    random dense input of shape ``(1, channels[0], *spatial_shape)``.

    Args:
        channels: ``(in_channels, out_channels)`` of the dense layer.
        spatial_shape: Spatial extent of the dense input volume.
        kernel_size: Kernel size, either an int or one int per spatial axis.
        device: Device on which the layer and input are created.
        dtype: Floating dtype of the layer and input.
        warmup_runs: Number of untimed forward passes.
        num_runs: Number of timed forward passes.

    Raises:
        ValueError: If *num_runs* is smaller than 1.
    """
    in_c, out_c = channels
    # Half-kernel padding; this equals the previous fixed padding of 1 for
    # kernel_size 3 and scales with other kernel sizes.
    if isinstance(kernel_size, int):
        padding = kernel_size // 2
    else:
        padding = tuple(k // 2 for k in kernel_size)

    dense_conv = nn.Conv3d(in_c, out_c, kernel_size, padding=padding).to(
        device=device, dtype=dtype
    )
    dummy_input = torch.randn(1, in_c, *spatial_shape, device=device, dtype=dtype)
    return _average_forward_time(dense_conv, dummy_input, device, warmup_runs, num_runs)


def main() -> None:
    """Run the smoke test end to end and print the results.

    Raises:
        AssertionError: If the network output does not have exactly
            ``--num-points`` active points.
    """
    args = parse_args()
    dtype = torch_dtype(args.dtype)

    print("spconv version:", getattr(spconv_root, "__version__", "unknown"))
    print("CUDA available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("CUDA device:", torch.cuda.get_device_name(0))

    batch_size = 1
    in_channels = 4
    out_channels = 32
    spatial_shape = (64, 64, 64)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(
        f"\n测试配置: batch_size={batch_size}, in_channels={in_channels}, "
        f"num_points={args.num_points}, spatial_shape={spatial_shape}, "
        f"device={device}, dtype={args.dtype}, num_runs={args.num_runs}"
    )

    x = create_sparse_input(
        batch_size, args.num_points, spatial_shape, in_channels, device, dtype
    )
    model = SimpleSparseConvNet(in_channels, out_channels).to(device=device, dtype=dtype)
    model.eval()

    print("\n模型结构:")
    print(model)

    # Single forward pass; this timing is taken before any warm-up call.
    print("\n--- 运行稀疏卷积前向传播 ---")
    synchronize(device)
    start_time = time.perf_counter()
    with torch.no_grad():
        output = model(x)
    synchronize(device)
    print(f"前向传播耗时: {(time.perf_counter() - start_time) * 1000:.2f} ms")
    print(f"输入非零特征数: {x.features.shape[0]}")
    print(f"输出非零特征数: {output.features.shape[0]}")
    print(f"输出 shape (features): {output.features.shape}")
    print(f"输出 dtype: {output.features.dtype}")

    print("\n--- 效率对比测试 (批量推理) ---")
    # Parameter count of one dense 3x3x3 Conv3d (weights + bias). The dense
    # baseline is a single layer, whereas the sparse model has two layers.
    dense_param_count = (in_channels * out_channels * 27) + out_channels
    sparse_param_count = sum(p.numel() for p in model.parameters())
    sparsity_ratio = np.prod(spatial_shape) / args.num_points
    print(f"稠密卷积核参数量: {dense_param_count:,}")
    print(f"稀疏卷积网络总参数量: {sparse_param_count:,}")
    print(f"稀疏率 (总格子 / 非零点数): {sparsity_ratio:.1f}")

    sparse_avg_time = run_sparse_forward(
        model, x, device, args.warmup_runs, args.num_runs
    )
    print(f"\n稀疏卷积平均耗时: {sparse_avg_time * 1000:.2f} ms")

    if not args.skip_dense:
        # Best effort: any failure of the dense baseline is reported and the
        # rest of the smoke test still runs.
        try:
            dense_avg_time = run_dense_forward(
                (in_channels, out_channels),
                spatial_shape,
                3,
                device,
                dtype,
                args.warmup_runs,
                args.num_runs,
            )
            print(f"稠密 卷积平均耗时: {dense_avg_time * 1000:.2f} ms")
            print(f"速度提升: {dense_avg_time / sparse_avg_time:.2f}x")
        except Exception as exc:
            print(f"稠密卷积测试失败: {exc}")
            print("(对于较大 spatial_shape 或 fp16 路径,稠密卷积可能不适合作为对比。)")

    print("\n--- 完整性验证 ---")
    # Explicit raise rather than `assert` so the check still runs under -O.
    if output.features.shape[0] != args.num_points:
        raise AssertionError("SubMConv 输出非零点数应与输入保持一致")
    print("✅ SubMConv 保持稀疏模式,输出非零点数量与输入一致。")

    try:
        dense_output = output.dense()
        print(f"✅ dense() 转换成功,shape: {dense_output.shape}, dtype: {dense_output.dtype}")
    except RuntimeError as exc:
        print(f"dense() 转换失败,通常是显存不足或当前 dtype 路径限制: {exc}")

    print("\n所有测试完成!")


if __name__ == "__main__":
    main()