import torch def batch_euclidean_dist(x, y): """ Args: x: pytorch Variable, with shape [Batch size, Local part, Feature channel] y: pytorch Variable, with shape [Batch size, Local part, Feature channel] Returns: dist: pytorch Variable, with shape [Batch size, Local part, Local part] """ assert len(x.size()) == 3 assert len(y.size()) == 3 assert x.size(0) == y.size(0) assert x.size(-1) == y.size(-1) N, m, d = x.size() N, n, d = y.size() # shape [N, m, n] xx = torch.pow(x, 2).sum(-1, keepdim=True).expand(N, m, n) yy = torch.pow(y, 2).sum(-1, keepdim=True).expand(N, n, m).permute(0, 2, 1) dist = xx + yy dist.baddbmm_(1, -2, x, y.permute(0, 2, 1)) dist = dist.clamp(min=1e-12).sqrt() # for numerical stability return dist def shortest_dist(dist_mat): """Parallel version. Args: dist_mat: pytorch Variable, available shape: 1) [m, n] 2) [m, n, N], N is batch size 3) [m, n, *], * can be arbitrary additional dimensions Returns: dist: three cases corresponding to `dist_mat`: 1) scalar 2) pytorch Variable, with shape [N] 3) pytorch Variable, with shape [*] """ m, n = dist_mat.size()[:2] # Just offering some reference for accessing intermediate distance. dist = [[0 for _ in range(n)] for _ in range(m)] for i in range(m): for j in range(n): if (i == 0) and (j == 0): dist[i][j] = dist_mat[i, j] elif (i == 0) and (j > 0): dist[i][j] = dist[i][j - 1] + dist_mat[i, j] elif (i > 0) and (j == 0): dist[i][j] = dist[i - 1][j] + dist_mat[i, j] else: dist[i][j] = torch.min(dist[i - 1][j], dist[i][j - 1]) + dist_mat[i, j] dist = dist[-1][-1] return dist def hard_example_mining(dist_mat, labels, return_inds=False): """For each anchor, find the hardest positive and negative sample. Args: dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] labels: pytorch LongTensor, with shape [N] return_inds: whether to return the indices. Save time if `False`(?) Returns: dist_ap: pytorch Variable, distance(anchor, positive); shape [N] dist_an: pytorch Variable, distance(anchor, negative); shape [N] p_inds: pytorch LongTensor, with shape [N]; indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 n_inds: pytorch LongTensor, with shape [N]; indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 NOTE: Only consider the case in which all labels have same num of samples, thus we can cope with all anchors in parallel. """ assert len(dist_mat.size()) == 2 assert dist_mat.size(0) == dist_mat.size(1) N = dist_mat.size(0) # shape [N, N] is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) # `dist_ap` means distance(anchor, positive) # both `dist_ap` and `relative_p_inds` with shape [N, 1] dist_ap, relative_p_inds = torch.max( dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) # `dist_an` means distance(anchor, negative) # both `dist_an` and `relative_n_inds` with shape [N, 1] dist_an, relative_n_inds = torch.min( dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) # shape [N] dist_ap = dist_ap.squeeze(1) dist_an = dist_an.squeeze(1) if return_inds: # shape [N, N] ind = (labels.new().resize_as_(labels) .copy_(torch.arange(0, N).long()) .unsqueeze( 0).expand(N, N)) # shape [N, 1] p_inds = torch.gather( ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) n_inds = torch.gather( ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) # shape [N] p_inds = p_inds.squeeze(1) n_inds = n_inds.squeeze(1) return dist_ap, dist_an, p_inds, n_inds return dist_ap, dist_an def euclidean_dist(x, y): """ Args: x: pytorch Variable, with shape [m, d] y: pytorch Variable, with shape [n, d] Returns: dist: pytorch Variable, with shape [m, n] """ m, n = x.size(0), y.size(0) xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() dist = xx + yy dist.addmm_(1, -2, x, y.t()) dist = dist.clamp(min=1e-12).sqrt() # for numerical stability return dist def batch_local_dist(x, y): """ Args: x: pytorch Variable, with shape [N, m, d] y: pytorch Variable, with shape [N, n, d] Returns: dist: pytorch Variable, with shape [N] """ assert len(x.size()) == 3 assert len(y.size()) == 3 assert x.size(0) == y.size(0) assert x.size(-1) == y.size(-1) # shape [N, m, n] dist_mat = batch_euclidean_dist(x, y) dist_mat = (torch.exp(dist_mat) - 1.) / (torch.exp(dist_mat) + 1.) # shape [N] dist = shortest_dist(dist_mat.permute(1, 2, 0)) return dist if __name__ == '__main__': x = torch.randn(32,2048) y = torch.randn(32,2048) dist_mat = euclidean_dist(x,y) dist_ap, dist_an, p_inds, n_inds = hard_example_mining(dist_mat,return_inds=True) from IPython import embed embed()