Commit e510b003 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Sample 1d peer memory halo exchanger

parent a61f0c25
from .peer_memory import PeerMemoryPool from .peer_memory import PeerMemoryPool
from .peer_halo_exchanger_1d import PeerHaloExchanger1d
import torch import torch
from apex.contrib.peer_memory import PeerMemoryPool from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d
import peer_memory as pm import peer_memory as pm
# How to run:
class HaloExchangerPeerMemory: # torchrun --nproc_per_node <num-GPU> <this-python-prog>
def __init__(self, rank, peer_group_size, peer_pool): # <num-GPU> must be a power of 2 greater than 1.
self.peer_group_size = peer_group_size
self.peer_rank = rank % peer_group_size
self.peer_pool = peer_pool
self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
self.signals[self.peer_rank].zero_()
def __call__(self, y, half_halo, H_split=True, explicit_nhwc=False, numSM=1):
channels_last = y.is_contiguous(memory_format=torch.channels_last)
if H_split:
if explicit_nhwc:
_, Hs, _, _ = list(y.shape)
H = Hs - 2*half_halo
top_out_halo = y[:,half_halo:2*half_halo,:,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
top_inp_halo = y[:,:half_halo,:,:]
btm_out_halo = y[:,H:H+half_halo,:,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
btm_inp_halo = y[:,H+half_halo:H+2*half_halo,:,:]
else:
_, _, Hs, _ = list(y.shape)
H = Hs - 2*half_halo
top_out_halo = y[:,:,half_halo:2*half_halo,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
top_inp_halo = y[:,:,:half_halo,:]
btm_out_halo = y[:,:,H:H+half_halo,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
btm_inp_halo = y[:,:,H+half_halo:H+2*half_halo,:]
else:
if explicit_nhwc:
_, _, Ws, _ = list(y.shape)
W = Ws - 2*half_halo
top_out_halo = y[:,:,half_halo:2*half_halo,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
top_inp_halo = y[:,:,:half_halo,:]
btm_out_halo = y[:,:,W:W+half_halo,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
btm_inp_halo = y[:,:,W+half_halo:W+2*half_halo,:]
else:
_, _, _, Ws = list(y.shape)
W = Ws - 2*half_halo
top_out_halo = y[:,:,:,half_halo:2*half_halo]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
top_inp_halo = y[:,:,:,:half_halo]
btm_out_halo = y[:,:,:,W:W+half_halo]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
btm_inp_halo = y[:,:,:,W+half_halo:W+2*half_halo]
top_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
btm_neighbor = (self.peer_rank + 1) % self.peer_group_size
pm.push_pull_halos_1d(
False, #True if self.peer_rank == 0 else False,
explicit_nhwc, numSM,
top_out_halo, top_tx[self.peer_rank], btm_tx[top_neighbor], top_inp_halo,
btm_out_halo, btm_tx[self.peer_rank], top_tx[btm_neighbor], btm_inp_halo,
self.signals[top_neighbor], self.signals[btm_neighbor], self.signals[self.peer_rank]
)
# Output of this function is used as ground truth in module tests.
def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_split): def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_split):
if explicit_nhwc: if explicit_nhwc:
if H_split: if H_split:
...@@ -132,7 +78,7 @@ def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, ...@@ -132,7 +78,7 @@ def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype,
y = y.to(memory_format=torch.channels_last) y = y.to(memory_format=torch.channels_last)
ym = y[:,:,:,half_halo:W+half_halo] ym = y[:,:,:,half_halo:W+half_halo]
y2 = y.clone() y2 = y.clone()
halo_ex(y, half_halo, H_split, explicit_nhwc, numSM) halo_ex(y, H_split, explicit_nhwc, numSM)
nccl_halo_ex(peer_rank, peer_group_size, y2, half_halo, explicit_nhwc, H_split) nccl_halo_ex(peer_rank, peer_group_size, y2, half_halo, explicit_nhwc, H_split)
is_equal = torch.all(torch.eq(y,y2)) is_equal = torch.all(torch.eq(y,y2))
if peer_rank == 0: if peer_rank == 0:
...@@ -184,12 +130,11 @@ def main(): ...@@ -184,12 +130,11 @@ def main():
torch.cuda.set_device(rank) torch.cuda.set_device(rank)
pool = PeerMemoryPool(rank, world_size, world_size, 64*1024, 2*1024*1024) pool = PeerMemoryPool(rank, world_size, world_size, 64*1024, 2*1024*1024)
halo_ex = HaloExchangerPeerMemory(rank, world_size, pool)
half_halo = 1 half_halo = 1
halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo)
H_split_tests(1,64,336,200, half_halo,rank,world_size,halo_ex) H_split_tests(1,64,336,200, half_halo,rank,world_size,halo_ex)
W_split_tests(1,64,200,336, half_halo,rank,world_size,halo_ex) W_split_tests(1,64,200,336, half_halo,world_size,halo_ex)
if __name__ == "__main__": if __name__ == "__main__":
......
import torch
from apex.contrib.peer_memory import PeerMemoryPool
import peer_memory as pm
class PeerHaloExchanger1d:
def __init__(self, rank, peer_group_size, peer_pool, half_halo):
self.peer_group_size = peer_group_size
self.peer_rank = rank % peer_group_size
self.peer_pool = peer_pool
self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
self.signals[self.peer_rank].zero_()
self.half_halo = half_halo
def __call__(self, y, H_split=True, explicit_nhwc=False, numSM=1, diagnostics=False):
channels_last = y.is_contiguous(memory_format=torch.channels_last) and not explicit_nhwc
if H_split:
if explicit_nhwc:
_, Hs, _, _ = list(y.shape)
H = Hs - 2*self.half_halo
top_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
top_inp_halo = y[:,:self.half_halo,:,:]
btm_out_halo = y[:,H:H+self.half_halo,:,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
btm_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
else:
_, _, Hs, _ = list(y.shape)
H = Hs - 2*self.half_halo
top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
top_inp_halo = y[:,:,:self.half_halo,:]
btm_out_halo = y[:,:,H:H+self.half_halo,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
btm_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
else:
if explicit_nhwc:
_, _, Ws, _ = list(y.shape)
W = Ws - 2*self.half_halo
top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
top_inp_halo = y[:,:,:self.half_halo,:]
btm_out_halo = y[:,:,W:W+self.half_halo,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
btm_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
else:
_, _, _, Ws = list(y.shape)
W = Ws - 2*self.half_halo
top_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
top_inp_halo = y[:,:,:,:self.half_halo]
btm_out_halo = y[:,:,:,W:W+self.half_halo]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
btm_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
top_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
btm_neighbor = (self.peer_rank + 1) % self.peer_group_size
pm.push_pull_halos_1d(
diagnostics, explicit_nhwc, numSM,
top_out_halo, top_tx[self.peer_rank], btm_tx[top_neighbor], top_inp_halo,
btm_out_halo, btm_tx[self.peer_rank], top_tx[btm_neighbor], btm_inp_halo,
self.signals[top_neighbor], self.signals[btm_neighbor], self.signals[self.peer_rank]
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment