Unverified commit 96780e6e, authored by ver217, committed by GitHub

Optimize pipeline schedule (#94)



* add pipeline shared module wrapper and update load batch

* added model parallel process group for amp and clip grad (#86)

* added model parallel process group for amp and clip grad

* update amp and clip with model parallel process group

* remove pipeline_prev/next group (#88)

* micro batch offload

* optimize pipeline gpu memory usage

* pipeline can receive tensor shape (#93)

* optimize pipeline gpu memory usage

* fix grad accumulation step counter

* rename classes and functions
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
parent e5b9f9a0
@@ -155,22 +155,12 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     if norm_type == inf:
         total_norm = max(p.grad.data.abs().max() for p in params)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-        ops = []
         # Take max across all model-parallel GPUs.
-        if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1:
-            ops.append(dist.all_reduce(total_norm_cuda,
-                                       op=dist.ReduceOp.MAX,
-                                       group=gpc.get_group(
-                                           ParallelMode.TENSOR),
-                                       async_op=True))
-        if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
-            ops.append(dist.all_reduce(total_norm_cuda,
-                                       op=dist.ReduceOp.MAX,
-                                       group=gpc.get_group(
-                                           ParallelMode.PIPELINE),
-                                       async_op=True))
-        for req in ops:
-            req.wait()
+        if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
+            dist.all_reduce(total_norm_cuda,
+                            op=dist.ReduceOp.MAX,
+                            group=gpc.get_group(ParallelMode.MODEL),
+                            async_op=False)
         total_norm = total_norm_cuda[0].item()
     else:
         tensor_parallel_grads = []
...
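For intuition, here is an illustrative sketch (not the library code) of why the two asynchronous MAX all-reduces over the TENSOR and PIPELINE groups can be collapsed into one blocking all-reduce over the combined MODEL group: the max over the union of the two groups equals the max of the two group maxima.

# Illustrative only: local inf-norms for a hypothetical 2 (pipeline) x 2 (tensor) grid.
import torch

local_norms = torch.tensor([[1.0, 3.0],   # pipeline stage 0: tensor ranks 0 and 1
                            [2.0, 5.0]])  # pipeline stage 1: tensor ranks 0 and 1

two_step = local_norms.max(dim=1).values.max()  # old scheme: tensor-group max, then pipeline-group max
one_step = local_norms.max()                    # new scheme: single max over the MODEL group
assert torch.equal(two_step, one_step)
print(one_step.item())  # 5.0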
@@ -65,6 +65,7 @@ class GradAccumOptimizer(ColossalaiOptimizer):
             self.optim.backward(scaled_loss)

     def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+        self.accumulate_step += 1
         no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size
         if no_sync:
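A minimal sketch of the counter logic this hunk adds, assuming only the attribute names visible above (accumulate_step, accumulate_size, is_torch_ddp); the real class also wraps an optimizer and a model, which are omitted here. The point of the fix is that backward_by_grad, which the pipeline schedule calls instead of backward, now advances the same counter, so DDP gradient synchronization is deferred until the last micro-step.

# Stripped-down illustration of the accumulation counter (not the library class).
class TinyGradAccum:
    def __init__(self, accumulate_size: int, is_torch_ddp: bool = True):
        self.accumulate_size = accumulate_size
        self.is_torch_ddp = is_torch_ddp
        self.accumulate_step = 0

    def backward_by_grad(self, tensor, grad):
        self.accumulate_step += 1  # the line added by this commit
        no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size
        return 'no_sync' if no_sync else 'all-reduce grads'


acc = TinyGradAccum(accumulate_size=4)
print([acc.backward_by_grad(None, None) for _ in range(4)])
# ['no_sync', 'no_sync', 'no_sync', 'all-reduce grads']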
@@ -81,7 +82,7 @@ class GradAccumDataloader():
     be updated only twice, at step 4 and step 8. The last two batches of data do not form a complete 4-step cycle.
     Thus, they will be automatically skipped by this class. If the dataloader is not a standard PyTorch dataloader
     (e.g. a Dali dataloader), this class will automatically consume (load data for nothing) the remaining 2 batches.
     :param dataloader: your dataloader object
     :type dataloader: Iterable
     :param accumulate_size: the number of steps to accumulate gradients
...
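The arithmetic the docstring describes, as a quick sketch (assumed behaviour read off the docstring, not the class internals): only whole accumulation cycles are kept, so with 10 batches and an accumulate size of 4 the trailing 2 batches are skipped or consumed for nothing.

# Sketch of the truncation described above (illustrative, not the library code).
def usable_batches(num_batches: int, accumulate_size: int) -> int:
    return (num_batches // accumulate_size) * accumulate_size

print(usable_batches(10, 4))  # 8 -> the last 2 batches never update the weights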
@@ -26,8 +26,6 @@ follow the steps below to create a new distributed initialization.
     GLOBAL = 'global'
     DATA = 'data'
     PIPELINE = 'pipe'
-    PIPELINE_PREV = 'pipe_prev'
-    PIPELINE_NEXT = 'pipe_next'
     ...
     NEW_MODE = 'new_mode'  # define your mode here
...
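For readers following the documented pattern, a hedged sketch of what the trimmed enum looks like after this commit (member list abbreviated; only the values shown in the diff are taken from the source):

from enum import Enum

class ParallelMode(Enum):
    GLOBAL = 'global'
    DATA = 'data'
    PIPELINE = 'pipe'
    # PIPELINE_PREV / PIPELINE_NEXT are gone as of this commit
    NEW_MODE = 'new_mode'  # define your mode here

print(ParallelMode.NEW_MODE.value)  # 'new_mode'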
@@ -18,8 +18,6 @@ class ParallelMode(Enum):
     GLOBAL = 'global'
     DATA = 'data'
     PIPELINE = 'pipe'
-    PIPELINE_PREV = 'pipe_prev'
-    PIPELINE_NEXT = 'pipe_next'
     ...
     NEW_MODE = 'new_mode'  # define your mode here
...
@@ -33,6 +33,12 @@ def check_pipeline_parallel_rank(rank):
         assert gpc.get_local_rank(ParallelMode.PIPELINE) == 1


+def check_model_parallel_rank(rank):
+    for i in range(8):
+        if rank in [i, i+8]:
+            assert gpc.get_local_rank(ParallelMode.MODEL) == i
+
+
 def check_tensor_parallel_rank(rank):
     if rank in [0, 4, 8, 12]:
         assert gpc.get_local_rank(ParallelMode.TENSOR) == 0
@@ -75,6 +81,7 @@ def init_2d(rank, world_size, backend, port, host):
     check_data_parallel_rank(rank)
     check_2d_parallel_rank(rank)
     check_pipeline_parallel_rank(rank)
+    check_model_parallel_rank(rank)
     gpc.destroy()
     torch.cuda.empty_cache()
...
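The assertions above imply a rank layout worth spelling out (an assumption read off the test, not documented API): with 16 GPUs and a model-parallel size of 8, global ranks r and r+8 occupy the same slot in the two data-parallel replicas, so their MODEL-local rank is r % 8. The same pattern holds in the 2.5D and 3D tests below with a model-parallel size of 16.

# Hypothetical layout check matching the 2D test's expectations.
MODEL_SIZE = 8  # pipeline size x tensor size in the 2D test (assumed)

def model_local_rank(global_rank: int) -> int:
    return global_rank % MODEL_SIZE

for i in range(MODEL_SIZE):
    assert model_local_rank(i) == i and model_local_rank(i + MODEL_SIZE) == i
print('rank layout consistent with the test')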
@@ -37,6 +37,12 @@ def check_pipeline_parallel_rank(rank):
         assert ppr == 1


+def check_model_parallel_rank(rank):
+    for i in range(16):
+        if rank in [i, i+16]:
+            assert gpc.get_local_rank(ParallelMode.MODEL) == i
+
+
 def check_tensor_parallel_rank(rank):
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
@@ -98,6 +104,7 @@ def init_2halfd(rank, world_size, backend, port, host):
     check_pipeline_parallel_rank(rank)
     check_tensor_parallel_rank(rank)
     check_2p5d_parallel_rank(rank)
+    check_model_parallel_rank(rank)
     gpc.destroy()
     torch.cuda.empty_cache()
...
@@ -37,6 +37,12 @@ def check_pipeline_parallel_rank(rank):
         assert ppr == 1


+def check_model_parallel_rank(rank):
+    for i in range(16):
+        if rank in [i, i+16]:
+            assert gpc.get_local_rank(ParallelMode.MODEL) == i
+
+
 def check_tensor_parallel_rank(rank):
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
@@ -90,6 +96,7 @@ def init_3d(rank, world_size, backend, port, host):
     check_3d_parallel_rank(rank)
     check_data_parallel_rank(rank)
     check_pipeline_parallel_rank(rank)
+    check_model_parallel_rank(rank)
     gpc.destroy()
     torch.cuda.empty_cache()
...
@@ -23,7 +23,7 @@ BATCH_SIZE = 16
 NUM_EPOCHS = 60
 WARMUP_EPOCHS = 5
 CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
-              fp16=dict(mode=AMP_TYPE.TORCH),
+              fp16=dict(mode=AMP_TYPE.NAIVE),
              gradient_accumulation=2)
...
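For context, the test config after this change looks roughly like the following sketch (only the fields visible in the diff are taken from the source; the import path is an assumption and may differ between Colossal-AI versions). The switch from the torch AMP mode to the naive FP16 mode presumably keeps the test aligned with the pipeline and clip-grad changes above.

from colossalai.amp import AMP_TYPE  # assumed import path

BATCH_SIZE = 16
NUM_EPOCHS = 60
WARMUP_EPOCHS = 5
CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
              fp16=dict(mode=AMP_TYPE.NAIVE),   # was AMP_TYPE.TORCH before this commit
              gradient_accumulation=2)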
@@ -75,49 +75,15 @@ def check_forward_backward(output_tensor, output_grad, rank, logger):
         rank, check_equal(grad, output_grad)))


-def check_op(size, rank, prev_rank, next_rank, up_group, down_group, logger):
+def check_comm(size, rank, prev_rank, next_rank, logger):
     dtype = torch.float32
     device = get_current_device()
     tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
-    # recv_tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
     grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
     tensor = torch.randn(tensor_shape, dtype=dtype, device=device)
     dist.all_reduce(tensor)
     grad = torch.randn(grad_shape, dtype=dtype, device=device)
     dist.all_reduce(grad)
-    if rank % 2 == 0:
-        need_meta = True
-        need_meta = send_tensor_meta(tensor, need_meta)
-        logger.info('Rank {} shape sent (need meta: {}).'.format(
-            rank, need_meta))
-        req = dist.broadcast(tensor, src=rank, group=down_group, async_op=True)
-        req.wait()
-        out = tensor.clone()
-        logger.info('Rank {} test op: tensor sent.'.format(rank))
-    else:
-        recv_tensor_shape = recv_tensor_meta(None)
-        logger.info('Rank {} shape received. Correct shape: {}'.format(
-            rank, tensor_shape == recv_tensor_shape))
-        out = torch.empty(recv_tensor_shape, dtype=dtype, device=device)
-        req = dist.broadcast(out, src=prev_rank, group=up_group, async_op=True)
-        req.wait()
-        logger.info('Rank {} test op: received tensor ({})'.format(
-            rank, out.shape))
-    logger.info('Rank {} test op. Correct tensor: {}'.format(
-        rank, check_equal(tensor, out)))
-
-
-def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger):
-    dtype = torch.float32
-    device = get_current_device()
-    tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
-    grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
-    tensor = torch.randn(tensor_shape, dtype=dtype, device=device)
-    dist.all_reduce(tensor)
-    grad = torch.randn(grad_shape, dtype=dtype, device=device)
-    dist.all_reduce(grad)
-    check_op(size, rank, prev_rank, next_rank, up_group, down_group, logger)
     check_forward(tensor, rank, logger)
     check_backward(grad, rank, logger)
     check_forward_backward(tensor, grad, rank, logger)
@@ -135,18 +101,13 @@ def run_check(rank, world_size, port):
     logger = get_dist_logger()
     rank = gpc.get_global_rank()
     prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
-    up_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_PREV)
-    up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
     next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
-    down_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_NEXT)
-    down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
     logger.info(
-        'Rank {0}: prev rank {1} (up: {2}), next rank {3} (down: {4})'.format(
-            rank, prev_rank, up_ranks, next_rank, down_ranks))
+        'Rank {0}: prev rank {1}, next rank {2}'.format(
+            rank, prev_rank, next_rank))
     logger.info('Distributed environment is initialzied.')

-    check_comm(world_size, rank, prev_rank, next_rank, up_group, down_group,
-               logger)
+    check_comm(world_size, rank, prev_rank, next_rank, logger)
     gpc.destroy()
     torch.cuda.empty_cache()
...
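Related to "pipeline can receive tensor shape (#93)" in the commit message: the removed check_op exercised send_tensor_meta/recv_tensor_meta over the dropped PIPELINE_PREV/PIPELINE_NEXT groups. A hedged sketch of the resulting control flow in a schedule (illustrative names, not the library API): if the output shape is supplied up front, the shape-metadata exchange can be skipped entirely.

# Illustrative only: how a schedule might decide whether to exchange shape metadata.
from typing import Callable, Optional, Tuple

def resolve_recv_shape(configured_shape: Optional[Tuple[int, ...]],
                       exchange_meta: Callable[[], Tuple[int, ...]]) -> Tuple[int, ...]:
    if configured_shape is not None:
        return configured_shape  # shape passed to the schedule: no extra communication
    return exchange_meta()       # otherwise fall back to sending/receiving tensor meta

# Example: shape known from the config vs. learned from the previous stage.
print(resolve_recv_shape((16, 64, 512), lambda: (16, 64, 512)))
print(resolve_recv_shape(None, lambda: (16, 64, 512)))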