# pbc_graph.py
"""
Copyright (c) 2025 Ma Zhaojia

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.

CUDA-accelerated PBC graph operations for atomic systems.
"""
import torch
from typing import Optional, List
from .pbc_graph_legacy import get_max_neighbors_mask
from .extensions.cuda_ops import get_pbc_graph_cuda_extension

def radius_graph_pbc_cuda(
    data,
    radius,
    max_num_neighbors_threshold,
    enforce_max_neighbors_strictly: bool = False,
    pbc: Optional[List[bool]] = None,
    dtype: torch.dtype = torch.float64,
):
    """
    Memory-efficient CUDA-accelerated version of radius_graph_pbc.

    All intra-structure atom pairs of the batch are enumerated on the Python
    side; the search over periodic images and the distance cutoff are
    delegated to the ``pbc_distance_cuda`` kernel of the compiled extension.

    Args:
        data: Batch object. The attributes ``pos``, ``natoms`` and ``cell``
            are read; ``pbc`` is read if present and is overwritten with its
            ``torch.atleast_2d`` form.
            NOTE(review): ``cell`` is indexed as (num_structures, 3, 3) and
            ``natoms`` as an integer tensor of per-structure atom counts -
            confirm against the caller.
        radius: Cutoff radius; forwarded to the kernel as a Python float.
        max_num_neighbors_threshold: Forwarded to ``get_max_neighbors_mask``.
        enforce_max_neighbors_strictly: Forwarded to
            ``get_max_neighbors_mask`` as ``enforce_max_strictly``.
        pbc: Per-axis periodicity flags, used as-is only when ``data`` has no
            ``pbc`` attribute. Defaults to ``[True, True, True]``. The
            caller's list is never modified.
        dtype: dtype of the empty ``unit_cell`` / distance placeholders that
            are created when the kernel reports no valid pair.

    Returns:
        Tuple ``(edge_index, unit_cell, num_neighbors_image)`` where
        ``edge_index`` is ``torch.stack((index2, index1))``, ``unit_cell``
        holds one row of three values per kept edge, and
        ``num_neighbors_image`` is the second value returned by
        ``get_max_neighbors_mask``.

    Raises:
        RuntimeError: If the structures in the batch disagree on the
            periodicity of any axis.
    """
    # Work on a private copy so a caller-supplied list is never mutated.
    pbc = [True, True, True] if pbc is None else list(pbc)

    device = data.pos.device
    batch_size = len(data.natoms)

    # Resolve the per-axis PBC flags from the batch; every structure must agree.
    if hasattr(data, "pbc"):
        data.pbc = torch.atleast_2d(data.pbc)
        for i in range(3):
            if not torch.any(data.pbc[:, i]).item():
                pbc[i] = False
            elif torch.all(data.pbc[:, i]).item():
                pbc[i] = True
            else:
                raise RuntimeError(
                    "Different structures in the batch have different PBC configurations."
                )

    atom_pos = data.pos

    # Each structure with n atoms contributes n**2 ordered (atom, atom) pairs.
    num_atoms_per_image = data.natoms
    num_atoms_per_image_sqr = (num_atoms_per_image**2).long()

    # Index of the first atom of each structure in the flattened batch.
    index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image

    index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr)
    num_atoms_per_image_expand = torch.repeat_interleave(
        num_atoms_per_image, num_atoms_per_image_sqr
    )

    # Running pair counter, restarted at 0 for every structure.
    num_atom_pairs = torch.sum(num_atoms_per_image_sqr)
    index_sqr_offset = (
        torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr
    )
    index_sqr_offset = torch.repeat_interleave(
        index_sqr_offset, num_atoms_per_image_sqr
    )
    atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset

    # Decode the pair counter into (row, col) atom indices via div / mod, then
    # shift by the structure offset to obtain batch-global indices.
    index1 = (
        torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor")
    ) + index_offset_expand
    index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand

    pos1 = torch.index_select(atom_pos, 0, index1)
    pos2 = torch.index_select(atom_pos, 0, index2)

    # Number of cell repetitions needed along each lattice direction:
    # ceil(radius * |a_j x a_k| / V), with (j, k) the other two lattice vectors.
    cell = data.cell
    cross_a2a3 = torch.cross(cell[:, 1], cell[:, 2], dim=-1)
    cell_vol = torch.sum(cell[:, 0] * cross_a2a3, dim=-1, keepdim=True)

    # Take the max over all structures so one repetition count serves the batch.
    # NOTE(review): the factor of 2 is forwarded unchanged from the original
    # code; how the kernel interprets max_rep is not visible here - confirm.
    max_rep = []
    for axis, (j, k) in enumerate(((1, 2), (2, 0), (0, 1))):
        if not pbc[axis]:
            # Non-periodic axis: no repetitions.
            max_rep.append(0)
            continue
        if axis == 0:
            cross = cross_a2a3  # already computed for the cell volume
        else:
            cross = torch.cross(cell[:, j], cell[:, k], dim=-1)
        inv_min_dist = torch.norm(cross / cell_vol, p=2, dim=-1)
        rep = torch.ceil(radius * inv_min_dist)
        max_rep.append(int(2 * rep.max().item()))

    # Pre-transpose the cell matrices once before handing them to the kernel.
    data_cell = torch.transpose(cell, 1, 2)

    pbc_graph_cuda = get_pbc_graph_cuda_extension()
    valid_pair_indices, unit_cell, atom_distance_sqr = pbc_graph_cuda.pbc_distance_cuda(
        pos1, pos2, data_cell,
        num_atoms_per_image_sqr, batch_size, max_rep, float(radius), device
    )

    # Map the kernel's pair indices back onto the batch-global atom indices.
    if valid_pair_indices.numel() > 0:
        valid_pair_indices = valid_pair_indices.long()
        index1 = index1.index_select(0, valid_pair_indices)
        index2 = index2.index_select(0, valid_pair_indices)
    else:
        # No pair within the cutoff: return consistently typed empty tensors.
        index1 = torch.empty(0, dtype=torch.long, device=device)
        index2 = torch.empty(0, dtype=torch.long, device=device)
        unit_cell = torch.empty(0, 3, dtype=dtype, device=device)
        atom_distance_sqr = torch.empty(0, dtype=dtype, device=device)

    # Sort by index1 (ascending) and reorder the companion arrays to match.
    if len(index1) > 0:
        sort_indices = torch.argsort(index1)
        index1 = index1[sort_indices]
        index2 = index2[sort_indices]
        unit_cell = unit_cell[sort_indices]
        atom_distance_sqr = atom_distance_sqr[sort_indices]

    mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask(
        natoms=data.natoms,
        index=index1,
        atom_distance=atom_distance_sqr,
        max_num_neighbors_threshold=max_num_neighbors_threshold,
        enforce_max_strictly=enforce_max_neighbors_strictly,
    )

    if not torch.all(mask_num_neighbors):
        # Drop the edges rejected by the neighbor-count mask.
        index1 = torch.masked_select(index1, mask_num_neighbors)
        index2 = torch.masked_select(index2, mask_num_neighbors)
        unit_cell = unit_cell.view(-1, 3)[mask_num_neighbors]

    edge_index = torch.stack((index2, index1))

    return edge_index, unit_cell, num_neighbors_image