Unverified Commit b2ce5906 authored by Thorsten Kurth's avatar Thorsten Kurth Committed by GitHub
Browse files

adding routines for cleaning up distributed process groups (#50)

parent 24fcb06e
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
# build after cloning in directory torch_harmonics via # build after cloning in directory torch_harmonics via
# docker build . -t torch_harmonics # docker build . -t torch_harmonics
FROM nvcr.io/nvidia/pytorch:24.07-py3 FROM nvcr.io/nvidia/pytorch:24.08-py3
COPY . /workspace/torch_harmonics COPY . /workspace/torch_harmonics
...@@ -38,6 +38,7 @@ COPY . /workspace/torch_harmonics ...@@ -38,6 +38,7 @@ COPY . /workspace/torch_harmonics
RUN pip install parameterized RUN pip install parameterized
# The custom CUDA extension does not support architectures < 7.0 # The custom CUDA extension does not support architectures < 7.0
ENV FORCE_CUDA_EXTENSION=1
ENV TORCH_CUDA_ARCH_LIST "7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX" ENV TORCH_CUDA_ARCH_LIST "7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
RUN pip install --global-option --cuda_ext /workspace/torch_harmonics RUN cd /workspace/torch_harmonics && pip install --no-build-isolation .
...@@ -112,6 +112,11 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -112,6 +112,11 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
# initializing sht # initializing sht
thd.init(cls.h_group, cls.w_group) thd.init(cls.h_group, cls.w_group)
@classmethod
def tearDownClass(cls):
    """Tear down the distributed state used by the tests in this class."""
    # First release what thd.init() was given, then call
    # destroy_process_group with group=None (the default process group
    # according to the torch.distributed documentation).
    thd.finalize()
    dist.destroy_process_group(group=None)
def _split_helper(self, tensor): def _split_helper(self, tensor):
with torch.no_grad(): with torch.no_grad():
# split in W # split in W
...@@ -185,7 +190,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -185,7 +190,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
[128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", False, 1e-5], [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 6, [3], 1, "equiangular", "equiangular", False, 1e-5], [128, 256, 128, 256, 32, 6, [3], 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
[129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5], [129, 256, 129, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-5],
[64, 128, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5], [64, 128, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", True, 1e-5],
......
...@@ -118,6 +118,11 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -118,6 +118,11 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
# initializing sht # initializing sht
thd.init(cls.h_group, cls.w_group) thd.init(cls.h_group, cls.w_group)
@classmethod
def tearDownClass(cls):
    """Tear down the distributed state used by the tests in this class."""
    # First release what thd.init() was given, then call
    # destroy_process_group with group=None (the default process group
    # according to the torch.distributed documentation).
    thd.finalize()
    dist.destroy_process_group(group=None)
def _split_helper(self, tensor): def _split_helper(self, tensor):
with torch.no_grad(): with torch.no_grad():
......
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
# #
# we need this in order to enable distributed # we need this in order to enable distributed
from .utils import init, is_initialized, polar_group, azimuth_group from .utils import init, finalize, is_initialized, polar_group, azimuth_group
from .utils import polar_group_size, azimuth_group_size, polar_group_rank, azimuth_group_rank from .utils import polar_group_size, azimuth_group_size, polar_group_rank, azimuth_group_rank
from .primitives import compute_split_shapes, split_tensor_along_dim from .primitives import compute_split_shapes, split_tensor_along_dim
from .primitives import ( from .primitives import (
......
...@@ -51,6 +51,13 @@ def init(polar_process_group, azimuth_process_group): ...@@ -51,6 +51,13 @@ def init(polar_process_group, azimuth_process_group):
_AZIMUTH_PARALLEL_GROUP = azimuth_process_group _AZIMUTH_PARALLEL_GROUP = azimuth_process_group
_IS_INITIALIZED = True _IS_INITIALIZED = True
def finalize():
    """Destroy the polar and azimuth process groups held by this module.

    Does nothing when ``is_initialized()`` is false. Otherwise
    ``_POLAR_PARALLEL_GROUP`` and ``_AZIMUTH_PARALLEL_GROUP`` are each
    destroyed, but only if the matching ``is_distributed_polar()`` /
    ``is_distributed_azimuth()`` check reports true.
    """
    if not is_initialized():
        return

    if is_distributed_polar():
        dist.destroy_process_group(_POLAR_PARALLEL_GROUP)
    if is_distributed_azimuth():
        # Previously spelled `ist.destroy_process_group`, an undefined name
        # that raised NameError and left the azimuth group alive.
        dist.destroy_process_group(_AZIMUTH_PARALLEL_GROUP)
    # NOTE(review): _IS_INITIALIZED is not reset here, so is_initialized()
    # keeps its previous value after finalize() -- confirm this is intended.
def is_initialized() -> bool: def is_initialized() -> bool:
return _IS_INITIALIZED return _IS_INITIALIZED
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment