Commit e0a0db1e authored by Ruilong Li's avatar Ruilong Li
Browse files

get rid of scatter max

parent b73fa67b
......@@ -2,7 +2,8 @@ from typing import Callable, List, Tuple, Union
import torch
from torch import nn
from torch_scatter import scatter_max
# from torch_scatter import scatter_max
def meshgrid3d(res: Tuple[int, int, int], device: torch.device = "cpu"):
......@@ -146,7 +147,8 @@ class OccupancyField(nn.Module):
bb_min, bb_max = torch.split(self.aabb, [self.num_dim, self.num_dim], dim=0)
x = x * (bb_max - bb_min) + bb_min
tmp_occ = self.occ_eval_fn(x).squeeze(-1)
tmp_occ_grid, _ = scatter_max(tmp_occ, indices, dim=0, out=tmp_occ_grid)
tmp_occ_grid[indices] = tmp_occ
# tmp_occ_grid, _ = scatter_max(tmp_occ, indices, dim=0, out=tmp_occ_grid)
# ema update
ema_mask = (self.occ_grid >= 0) & (tmp_occ_grid >= 0)
......
ninja
pybind11
torch
# torch-scatter -f https://data.pyg.org/whl/torch-1.12.0+cu102.html
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment