Unverified Commit c70fafeb authored by ShawnHu's avatar ShawnHu Committed by GitHub
Browse files

Add type hints in ops/assign_score_withk.py (#2023)

parent de90c7a2
from typing import Tuple
import torch
from torch.autograd import Function from torch.autograd import Function
from ..utils import ext_loader from ..utils import ext_loader
...@@ -27,11 +30,11 @@ class AssignScoreWithK(Function): ...@@ -27,11 +30,11 @@ class AssignScoreWithK(Function):
@staticmethod @staticmethod
def forward(ctx, def forward(ctx,
scores, scores: torch.Tensor,
point_features, point_features: torch.Tensor,
center_features, center_features: torch.Tensor,
knn_idx, knn_idx: torch.Tensor,
aggregate='sum'): aggregate: str = 'sum') -> torch.Tensor:
""" """
Args: Args:
scores (torch.Tensor): (B, npoint, K, M), predicted scores to scores (torch.Tensor): (B, npoint, K, M), predicted scores to
...@@ -78,7 +81,9 @@ class AssignScoreWithK(Function): ...@@ -78,7 +81,9 @@ class AssignScoreWithK(Function):
return output return output
@staticmethod @staticmethod
def backward(ctx, grad_out): def backward(
ctx, grad_out: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
""" """
Args: Args:
grad_out (torch.Tensor): (B, out_dim, npoint, K) grad_out (torch.Tensor): (B, out_dim, npoint, K)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment