Unverified Commit c70fafeb authored by ShawnHu's avatar ShawnHu Committed by GitHub
Browse files

Add type hints in ops/assign_score_withk.py (#2023)

parent de90c7a2
from typing import Tuple
import torch
from torch.autograd import Function
from ..utils import ext_loader
......@@ -27,11 +30,11 @@ class AssignScoreWithK(Function):
@staticmethod
def forward(ctx,
scores,
point_features,
center_features,
knn_idx,
aggregate='sum'):
scores: torch.Tensor,
point_features: torch.Tensor,
center_features: torch.Tensor,
knn_idx: torch.Tensor,
aggregate: str = 'sum') -> torch.Tensor:
"""
Args:
scores (torch.Tensor): (B, npoint, K, M), predicted scores to
......@@ -78,7 +81,9 @@ class AssignScoreWithK(Function):
return output
@staticmethod
def backward(ctx, grad_out):
def backward(
ctx, grad_out: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
"""
Args:
grad_out (torch.Tensor): (B, out_dim, npoint, K)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment