import torch


def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    """Reshape ``src`` with singleton axes and expand it to ``other.size()``.

    A 1-D ``src`` is treated as a vector laid along axis ``dim`` of
    ``other``: ``dim`` leading singleton axes are inserted in front of it.
    Any ``src`` (1-D or not) then gets trailing singleton axes appended
    until its rank equals ``other.dim()``. The result is
    ``Tensor.expand(other.size())``, a broadcast view rather than a copy.

    Args:
        src: Tensor to broadcast.
        other: Tensor whose size is the broadcast target.
        dim: Axis of ``other`` that a 1-D ``src`` should line up with.
            Negative values count from the end of ``other``'s axes.

    Returns:
        ``src`` viewed with ``other``'s size. ``expand`` raises if the
        reshaped ``src`` is not expandable to that size.
    """
    ndim = other.dim()
    if dim < 0:
        dim += ndim
    # NOTE(review): ``dim`` is not range-checked. A negative ``dim`` that is
    # still negative after normalization adds no leading axes, so it behaves
    # like ``dim == 0`` for a 1-D ``src``. Confirm callers never pass one.
    out = src
    if out.dim() == 1:
        # Move the vector's only axis to position ``dim``.
        lead = 0
        while lead < dim:
            out = out.unsqueeze(0)
            lead += 1
    # Pad with trailing singleton axes up to ``other``'s rank.
    while out.dim() < ndim:
        out = out.unsqueeze(-1)
    return out.expand(other.size())