"src/vscode:/vscode.git/clone" did not exist on "5ffb73d4aeac9eaef8366d7b21872d64009bd1c7"
Commit 628bf0dd authored by Neel Kant

Use the new allgather implementation

parent 98feae4e
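This commit drops the experimental IMPLEMENTATION switch in forward_step (see the last hunk) and keeps only the autograd-aware all-gather. The removed 'original' path built a zero-filled global tensor and summed it with dist.all_reduce; all_gather exchanges only each rank's chunk. A minimal sketch of the two gathers, assuming an initialized process group (run under e.g. torchrun); the function names here are illustrative, not from the file:

    import torch
    import torch.distributed as dist

    def gather_via_allreduce(local, rank, world_size):
        # old path: zero-pad a global tensor, write the local chunk, sum across ranks
        out = torch.zeros(world_size * local.shape[0], local.shape[1], dtype=local.dtype)
        out[rank * local.shape[0]:(rank + 1) * local.shape[0]] = local
        dist.all_reduce(out)  # summing zero-padded copies yields the concatenation
        return out

    def gather_via_allgather(local, world_size):
        # new path: each rank contributes its chunk directly
        chunks = [torch.empty_like(local) for _ in range(world_size)]
        dist.all_gather(chunks, local)
        return torch.cat(chunks, dim=0)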
@@ -57,7 +57,6 @@ def model_provider():
     return general_model_provider(False, False)

 def get_group_world_size_rank():
     group = mpu.get_data_parallel_group()
@@ -67,23 +66,10 @@ def get_group_world_size_rank():
     return group, rank, world_size

-def get_rank_chunk_along_first_dim(tensor):
-    group, rank, world_size = get_group_world_size_rank()
-    assert tensor.shape[0] % world_size == 0
-    dim_size = tensor.shape[0] // world_size
-    output_list = torch.split(tensor, dim_size, dim=0)
-    output = output_list[rank].contiguous()
-    return output
-
 class AllgatherFromDataParallelRegion(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input_):
         assert input_.dim() == 2
         group, rank, world_size = get_group_world_size_rank()
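The rest of the forward pass of AllgatherFromDataParallelRegion is collapsed in this view. For context, a hedged reconstruction of what the forward typically looks like given the backward added in the next hunk, assuming torch.distributed.all_gather over the data-parallel group (the tensor_list construction is an assumption, not shown in this diff):

    import torch

    class AllgatherSketch(torch.autograd.Function):
        # hypothetical stand-in for the collapsed forward of
        # AllgatherFromDataParallelRegion; backward mirrors the hunk below
        @staticmethod
        def forward(ctx, input_):
            assert input_.dim() == 2
            group, rank, world_size = get_group_world_size_rank()  # helper from this file
            # every rank contributes its [local_batch, hidden] chunk
            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
            tensor_list[rank] = input_
            torch.distributed.all_gather(tensor_list, input_, group=group)
            # concatenate along the batch dimension -> [global_batch, hidden]
            return torch.cat(tensor_list, dim=0).contiguous()

        @staticmethod
        def backward(ctx, grad_output):
            # keep only this rank's chunk of the incoming gradient
            group, rank, world_size = get_group_world_size_rank()
            dim_size = grad_output.shape[0] // world_size
            return torch.split(grad_output, dim_size, dim=0)[rank].contiguous()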
@@ -98,32 +84,17 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
-        return get_rank_chunk_along_first_dim(grad_output)
-
-
-class AllReduceFromDataParallelRegion(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input_):
-        assert input_.dim() == 2
-        group, rank, world_size = get_group_world_size_rank()
-        tensor_list = [torch.zeros_like(input_) for _ in range(world_size)]
-        tensor_list[rank] = input_
-        output = torch.cat(tensor_list, dim=0).contiguous()
-        torch.distributed.all_reduce(output, group=group)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return get_rank_chunk_along_first_dim(grad_output)
+        group, rank, world_size = get_group_world_size_rank()
+        assert grad_output.shape[0] % world_size == 0
+        dim_size = grad_output.shape[0] // world_size
+        output_list = torch.split(grad_output, dim_size, dim=0)
+        # get chunk from this rank
+        output = output_list[rank].contiguous()
+        return output

 def get_batch(data_iterator):
     # Items and their type.
     keys = ['query_tokens', 'query_pad_mask',
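Why the new backward is just a slice: the forward concatenates per-rank chunks along dim 0, so the gradient flowing to the local input is exactly this rank's chunk of the gradient of the gathered tensor. A single-process illustration (world_size and rank are simulated here, no process group needed):

    import torch

    world_size, rank = 2, 0
    local = torch.randn(4, 8, requires_grad=True)   # this rank's activations
    peer = torch.randn(4, 8)                        # stand-in for the other rank's chunk
    gathered = torch.cat([local, peer], dim=0)      # what all_gather + cat produces
    gathered.sum().backward()
    # the slice rule from the backward above reproduces autograd's result
    grad_full = torch.ones_like(gathered)           # grad of sum() wrt gathered
    dim_size = grad_full.shape[0] // world_size
    chunk = torch.split(grad_full, dim_size, dim=0)[rank].contiguous()
    assert torch.equal(local.grad, chunk)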
@@ -159,38 +130,14 @@ def forward_step(data_iterator, model):
     block_tokens, block_pad_mask, block_indices = get_batch(data_iterator)
     timers('batch generator').stop()

     # Forward model.
     query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)

-    local_batch_size = query_logits.shape[0]
-    global_batch_size = dist.get_world_size() * local_batch_size
-    IMPLEMENTATION = 'original'
-
-    if IMPLEMENTATION == 'original':
-        data_parallel_size = dist.get_world_size() / args.model_parallel_size
-        batch_size = query_logits.shape[0]
-        global_batch_size = int(batch_size * data_parallel_size)
-        all_logits_shape = (int(global_batch_size), int(query_logits.shape[1]))
-        all_query_logits = torch.cuda.FloatTensor(*all_logits_shape).type(query_logits.dtype).fill_(0.0)
-        all_block_logits = all_query_logits.clone()
-
-        # record this process's data
-        all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
-        all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits
-
-        # merge data from all processes
-        dist.all_reduce(all_query_logits)
-        dist.all_reduce(all_block_logits)
-    elif IMPLEMENTATION == 'allgather':
-        all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
-        all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
-    elif IMPLEMENTATION == 'allreduce':
-        all_query_logits = AllReduceFromDataParallelRegion.apply(query_logits)
-        all_block_logits = AllReduceFromDataParallelRegion.apply(block_logits)
-    else:
-        raise Exception('should not be here.')
+    # recall we assert that model_parallel_size == 1
+    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
+    all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)

     # scores are inner products between query and block embeddings
     retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
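retrieval_scores is a [global_batch, global_batch] matrix of query-block inner products; gathering logits across data-parallel ranks is what turns every other sample in the global batch into an in-batch negative. The loss itself sits outside this hunk, so the cross-entropy below is an assumption: a minimal sketch with matching query/block pairs on the diagonal:

    import torch
    import torch.nn.functional as F

    global_batch, hidden = 16, 128
    all_query_logits = torch.randn(global_batch, hidden)
    all_block_logits = torch.randn(global_batch, hidden)
    retrieval_scores = all_query_logits.float().matmul(
        torch.transpose(all_block_logits, 0, 1).float())  # [batch, batch]
    # assumed loss: row i's positive block is column i (in-batch negatives)
    targets = torch.arange(global_batch)
    loss = F.cross_entropy(retrieval_scores, targets)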