Unverified Commit d77557b0 authored by Grapymage's avatar Grapymage Committed by GitHub
Browse files

[Fix] Fix index error when using multi-samplers strategy (#2094)

* fix index error when using multi-samplers strategy

* After every loop, change the last_fps_end_index to fps_sample_range
parent c67ab9a3
...@@ -104,7 +104,6 @@ class PointsSampler(nn.Module): ...@@ -104,7 +104,6 @@ class PointsSampler(nn.Module):
""" """
indices = [] indices = []
last_fps_end_index = 0 last_fps_end_index = 0
for fps_sample_range, sampler, npoint in zip( for fps_sample_range, sampler, npoint in zip(
self.fps_sample_range_list, self.samplers, self.num_point): self.fps_sample_range_list, self.samplers, self.num_point):
assert fps_sample_range < points_xyz.shape[1] assert fps_sample_range < points_xyz.shape[1]
...@@ -116,8 +115,8 @@ class PointsSampler(nn.Module): ...@@ -116,8 +115,8 @@ class PointsSampler(nn.Module):
else: else:
sample_features = None sample_features = None
else: else:
sample_points_xyz = \ sample_points_xyz = points_xyz[:, last_fps_end_index:
points_xyz[:, last_fps_end_index:fps_sample_range] fps_sample_range]
if features is not None: if features is not None:
sample_features = features[:, :, last_fps_end_index: sample_features = features[:, :, last_fps_end_index:
fps_sample_range] fps_sample_range]
...@@ -128,7 +127,7 @@ class PointsSampler(nn.Module): ...@@ -128,7 +127,7 @@ class PointsSampler(nn.Module):
npoint) npoint)
indices.append(fps_idx + last_fps_end_index) indices.append(fps_idx + last_fps_end_index)
last_fps_end_index += fps_sample_range last_fps_end_index = fps_sample_range
indices = torch.cat(indices, dim=1) indices = torch.cat(indices, dim=1)
return indices return indices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment