"""basic scatter_sum operations from torch_scatter from
https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py
Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency.
PyTorch plans to move these features into the main repo, but until then,
to make installation simpler, we need this pure python set of wrappers
that don't require installing PyTorch C++ extensions.
See https://github.com/pytorch/pytorch/issues/63780.
"""
from typing import Optional

import torch


def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int) -> torch.Tensor:
    # Reshape `src` (typically a 1-D index tensor) so it can be expanded to
    # the shape of `other`: singleton dims are inserted before and after `dim`.
    if dim < 0:
        dim = other.dim() + dim
    if src.dim() == 1:
        for _ in range(0, dim):
            src = src.unsqueeze(0)
    for _ in range(src.dim(), other.dim()):
        src = src.unsqueeze(-1)
    src = src.expand_as(other)
    return src
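
# For example (illustrative, added for clarity; not part of the upstream code):
# broadcasting a 1-D tensor of shape (3,) against a (2, 3, 4) tensor along
# dim=1 first reshapes it to (1, 3, 1) and then expands it, so
#   _broadcast(torch.arange(3), torch.zeros(2, 3, 4), dim=1).shape == (2, 3, 4)
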
def scatter_sum(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: Optional[torch.Tensor] = None,
    dim_size: Optional[int] = None,
    reduce: str = "sum",
) -> torch.Tensor:
    assert reduce == "sum"  # for now, TODO
    index = _broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            # Infer the output length along `dim` from the largest index.
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
    return out.scatter_add_(dim, index, src)
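
# Worked example (illustrative, added for clarity): summing three rows into
# two buckets along dim 0.
#
#   >>> src = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
#   >>> index = torch.tensor([0, 1, 0])
#   >>> scatter_sum(src, index, dim=0)
#   tensor([[6., 8.],
#           [3., 4.]])
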
def scatter_std(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: Optional[torch.Tensor] = None,
    dim_size: Optional[int] = None,
    unbiased: bool = True,
) -> torch.Tensor:
    if out is not None:
        dim_size = out.size(dim)

    if dim < 0:
        dim = src.dim() + dim

    count_dim = dim
    if index.dim() <= dim:
        count_dim = index.dim() - 1

    # Number of elements scattered into each output slot.
    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

    index = _broadcast(index, src, dim)
    tmp = scatter_sum(src, index, dim, dim_size=dim_size)
    count = _broadcast(count, tmp, dim).clamp(1)
    mean = tmp.div(count)

    # Sum of squared deviations from the per-slot mean.
    var = src - mean.gather(dim, index)
    var = var * var
    out = scatter_sum(var, index, dim, out, dim_size)

    if unbiased:
        # Bessel's correction: divide by (count - 1), clamped to avoid 0.
        count = count.sub(1).clamp_(1)
    out = out.div(count + 1e-6).sqrt()

    return out

def scatter_mean(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: Optional[torch.Tensor] = None,
    dim_size: Optional[int] = None,
) -> torch.Tensor:
    out = scatter_sum(src, index, dim, out, dim_size)
    dim_size = out.size(dim)

    index_dim = dim
    if index_dim < 0:
        index_dim = index_dim + src.dim()
    if index.dim() <= index_dim:
        index_dim = index.dim() - 1

    # Divide each sum by the number of contributing elements (at least 1).
    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, index_dim, None, dim_size)
    count[count < 1] = 1
    count = _broadcast(count, out, dim)
    if out.is_floating_point():
        out.true_divide_(count)
    else:
        out.div_(count, rounding_mode="floor")
    return out
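
Building on the worked example above, a quick sanity check of the two derived reductions (an illustrative sketch, not part of the commit; it assumes the functions above are in scope):

import torch

src = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
index = torch.tensor([0, 1, 0])  # rows 0 and 2 fall into bucket 0, row 1 into bucket 1

print(scatter_mean(src, index, dim=0))  # tensor([[3., 4.], [3., 4.]])
print(scatter_std(src, index, dim=0))   # ~tensor([[2.8284, 2.8284], [0.0000, 0.0000]])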
###########################################################################################
# Slurm environment setup for distributed training.
# This code is refactored from rsarm's contribution at:
# https://github.com/Lumi-supercomputer/lumi-reframe-tests/blob/main/checks/apps/deeplearning/pytorch/src/pt_distr_env.py
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import os

import hostlist


class DistributedEnvironment:
    """Expose torch.distributed rendezvous settings derived from Slurm."""

    def __init__(self):
        self._setup_distr_env()
        self.master_addr = os.environ["MASTER_ADDR"]
        self.master_port = os.environ["MASTER_PORT"]
        self.world_size = int(os.environ["WORLD_SIZE"])
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.rank = int(os.environ["RANK"])

    def _setup_distr_env(self):
        # The first node of the allocation serves as the rendezvous host.
        hostname = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])[0]
        os.environ["MASTER_ADDR"] = hostname
        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "33333")
        # Prefer SLURM_NTASKS; otherwise derive the world size from
        # tasks-per-node times the number of nodes.
        os.environ["WORLD_SIZE"] = os.environ.get(
            "SLURM_NTASKS",
            str(
                int(os.environ["SLURM_NTASKS_PER_NODE"])
                * int(os.environ["SLURM_NNODES"])
            ),
        )
        os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
        os.environ["RANK"] = os.environ["SLURM_PROCID"]

    def __repr__(self):
        return (
            f"DistributedEnvironment(master_addr={self.master_addr}, master_port={self.master_port}, "
            f"world_size={self.world_size}, local_rank={self.local_rank}, rank={self.rank})"
        )
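
Typical usage on a Slurm job (a minimal sketch, not part of the commit; it assumes the script is launched with srun so the SLURM_* variables are set, and that NCCL and a GPU are available):

import torch
import torch.distributed as dist

distr_env = DistributedEnvironment()  # reads SLURM_* and exports MASTER_ADDR etc.
dist.init_process_group(backend="nccl", init_method="env://")  # picks up the exported variables
torch.cuda.set_device(distr_env.local_rank)
print(distr_env)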
# Trimmed-down `pytorch_geometric`
MACE uses the [`pytorch_geometric`](https://pytorch-geometric.readthedocs.io/en/latest/) [1, 2] framework. However, we only use a very limited subset of that library: the most basic graph data structures.
We follow the same approach as NequIP (https://github.com/mir-group/nequip/tree/main/nequip) and copy their code here.
To avoid adding a large number of unnecessary second-degree dependencies, and to simplify installation, we include and modify here the small subset of `torch_geometric` that is necessary for our code.
We are grateful to the developers of PyTorch Geometric for their ongoing and very useful work on graph learning with PyTorch.
[1] Fey, M., & Lenssen, J. E. (2019). Fast Graph Representation Learning with PyTorch Geometric (Version 2.0.1) [Computer software]. https://github.com/pyg-team/pytorch_geometric <br>
[2] https://arxiv.org/abs/1903.02428
from .batch import Batch
from .data import Data
from .dataloader import DataLoader
from .dataset import Dataset
from .seed import seed_everything
__all__ = ["Batch", "Data", "Dataset", "DataLoader", "seed_everything"]
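
Because the vendored classes keep the upstream interface, basic usage mirrors `pytorch_geometric` (a sketch; the `mace.tools.torch_geometric` import path and the `num_graphs`/`batch` attributes are assumptions based on the upstream PyG API):

import torch
from mace.tools.torch_geometric import Data, DataLoader  # assumed location of this package

# A two-node graph with one edge in each direction.
edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
data = Data(x=torch.randn(2, 3), edge_index=edge_index)

loader = DataLoader([data, data], batch_size=2)
batch = next(iter(loader))
print(batch.num_graphs)  # 2: both graphs merged into one disjoint graph
print(batch.batch)       # tensor([0, 0, 1, 1]): node-to-graph assignment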