Commit 6e7e2d90 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Rename peer_memory extension to peer_memory_cuda

parent fa8e7d99
...@@ -2,7 +2,7 @@ import torch ...@@ -2,7 +2,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from torch import nn from torch import nn
import nccl_p2p as inc import nccl_p2p as inc
import peer_memory as pm import peer_memory_cuda as pm
# Communication free halo exchanger. # Communication free halo exchanger.
# NB! This halo exchanger does not exchange halos with neighbors as it should, it merely swaps the inputs # NB! This halo exchanger does not exchange halos with neighbors as it should, it merely swaps the inputs
......
import torch import torch
from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d
import peer_memory as pm import peer_memory_cuda as pm
# How to run: # How to run:
# torchrun --nproc_per_node <num-GPU> <this-python-prog> # torchrun --nproc_per_node <num-GPU> <this-python-prog>
......
import torch import torch
from apex.contrib.peer_memory import PeerMemoryPool from apex.contrib.peer_memory import PeerMemoryPool
import peer_memory as pm import peer_memory_cuda as pm
class PeerHaloExchanger1d: class PeerHaloExchanger1d:
def __init__(self, rank, peer_group_size, peer_pool, half_halo): def __init__(self, rank, peer_group_size, peer_pool, half_halo):
......
import torch import torch
import numpy as np import numpy as np
import peer_memory as pm import peer_memory_cuda as pm
class PeerMemoryPool(object): class PeerMemoryPool(object):
......
...@@ -632,7 +632,7 @@ if "--peer_memory" in sys.argv: ...@@ -632,7 +632,7 @@ if "--peer_memory" in sys.argv:
raise_if_cuda_home_none("--peer_memory") raise_if_cuda_home_none("--peer_memory")
ext_modules.append( ext_modules.append(
CUDAExtension( CUDAExtension(
name="peer_memory", name="peer_memory_cuda",
sources=[ sources=[
"apex/contrib/csrc/peer_memory/peer_memory_cuda.cu", "apex/contrib/csrc/peer_memory/peer_memory_cuda.cu",
"apex/contrib/csrc/peer_memory/peer_memory.cpp", "apex/contrib/csrc/peer_memory/peer_memory.cpp",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment