#include #include at::Tensor Aggregate_Forward_CPU( const at::Tensor A, const at::Tensor X, const at::Tensor C) { auto E = (A.unsqueeze(3) * (X.unsqueeze(2).expand({X.size(0), X.size(1), C.size(0), C.size(1)}) - C.unsqueeze(0).unsqueeze(0))).sum(1); return E; } std::vector Aggregate_Backward_CPU( const at::Tensor GE, const at::Tensor A, const at::Tensor X, const at::Tensor C) { auto gradA = (GE.unsqueeze(1) * (X.unsqueeze(2).expand({X.size(0), X.size(1), C.size(0), C.size(1)}) - C.unsqueeze(0).unsqueeze(0))).sum(3); auto gradX = at::bmm(A, GE); auto gradC = (-GE * A.sum(1).unsqueeze(2)).sum(0); return {gradA, gradX, gradC}; } at::Tensor ScaledL2_Forward_CPU( const at::Tensor X, const at::Tensor C, const at::Tensor S) { auto SL = S.view({1, 1, C.size(0)}) * (X.unsqueeze(2).expand({X.size(0), X.size(1), C.size(0), C.size(1)}) - C.unsqueeze(0).unsqueeze(0)).pow(2).sum(3); return SL; } std::vector ScaledL2_Backward_CPU( const at::Tensor GSL, const at::Tensor X, const at::Tensor C, const at::Tensor S, const at::Tensor SL) { auto tmp = (2 * GSL * S.view({1, 1, C.size(0)})).unsqueeze(3) * (X.unsqueeze(2).expand({X.size(0), X.size(1), C.size(0), C.size(1)}) - C.unsqueeze(0).unsqueeze(0)); auto GX = tmp.sum(2); auto GC = tmp.sum(0).sum(0); auto GS = (GSL * (SL / S.view({1, 1, C.size(0)}))).sum(0).sum(0); return {GX, GC, GS}; }