projection_deleteme.py 2 KB
Newer Older
Stella Biderman's avatar
Stella Biderman committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch

def batch_vector_projection(vectors: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Project every vector in a batch onto a single target direction.

    The target is normalized to unit length first, so the projection of a
    vector ``v`` is ``(v . t_hat) * t_hat`` with ``t_hat = target / ||target||``.
    A zero target yields all-zero outputs, because
    ``torch.nn.functional.normalize`` clamps the norm at a small epsilon.

    Args:
        vectors: Tensor of shape (b, p, d) where:
                b is the batch size
                p is the number of vectors per batch
                d is the dimension of each vector
                More generally, any shape (..., d) is accepted. The target is
                broadcast over all leading dimensions.
        target: Tensor holding d elements, the direction to project onto.
                Normally shape (d,). Other shapes with d elements, such as
                (1, d), are flattened before normalization.

    Returns:
        A tuple ``(projections, dot_products)``:
            projections: shape (b, p, d), or (..., d) in general. Holds the
                component of each vector along ``target``.
            dot_products: shape (b, p, 1), or (..., 1) in general. Holds the
                dot product of each vector with the unit-normalized target,
                i.e. the signed length of each projection.

    Example:
        b, p, d = 32, 10, 3  # batch of 32, 10 vectors each, in 3D
        vectors = torch.randn(b, p, d)
        target = torch.randn(d)
        projections, dot_products = batch_vector_projection(vectors, target)
    """
    # Flatten before normalizing. normalize(dim=0) on a (1, d) tensor would
    # normalize each column on its own (giving +/-1 entries) instead of
    # normalizing the vector as a whole.
    target = torch.nn.functional.normalize(target.reshape(-1), dim=0)

    # Signed length of each vector along the unit target.
    # A (d,) tensor broadcasts against any (..., d) input, so no reshape is
    # needed. keepdim=True leaves a trailing size-1 axis for the next step.
    # An elementwise product + sum (rather than matmul) keeps PyTorch's
    # dtype promotion when vectors and target differ in dtype.
    dot_products = torch.sum(vectors * target, dim=-1, keepdim=True)

    # Scale the unit target by each signed length: shape (..., d).
    projections = dot_products * target

    return projections, dot_products

# Test function
if __name__ == "__main__":
    # Seed the RNG so the randomized self-check is reproducible from run to run.
    torch.manual_seed(0)

    # Create sample data
    batch_size, vectors_per_batch, dim = 2, 3, 4
    vectors = torch.randn(batch_size, vectors_per_batch, dim)
    target = torch.randn(dim)

    # Compute projections
    projected, dot_products = batch_vector_projection(vectors, target)

    # Projections keep the input shape; dot products keep a trailing size-1 axis.
    assert projected.shape == vectors.shape
    assert dot_products.shape == (batch_size, vectors_per_batch, 1)

    # Cross-check against the closed form proj_t(v) = (v . t / t . t) * t,
    # which needs no explicit normalization of t.
    expected = (vectors @ target / target.dot(target)).unsqueeze(-1) * target
    assert torch.allclose(projected, expected, atol=1e-5)

    # The residual v - proj_t(v) must be orthogonal to the target.
    _, zero_dot_products = batch_vector_projection(vectors - projected, target)
    assert torch.allclose(zero_dot_products, torch.zeros_like(zero_dot_products), atol=1e-6)
    print("Without proj, close to zero")
    # Verify shapes
    print(f"Input shape: {vectors.shape}")
    print(f"Target shape: {target.shape}")
    print(f"Output shape: {projected.shape}")