Commit a9a3ef50 authored by rprenger

Simpler broadcasting and some clean up

parent 5580d661
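The simplification the commit message refers to: the old code broadcast each tensor from mpu.get_tensor_model_parallel_src_rank() within the tensor-model-parallel group, while the new code broadcasts from global rank 0 over the default process group. Below is a minimal, standalone sketch of the new call form (not Megatron code; it assumes a torchrun launch so torch.distributed can initialize from environment variables):

import torch
import torch.distributed as dist

# Assumes launch via `torchrun --nproc_per_node=N this_file.py`,
# which sets RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT.
dist.init_process_group(backend="gloo")

payload = torch.zeros(3, dtype=torch.int64)
if dist.get_rank() == 0:
    # e.g. [batch_size, seq_len, max_len]
    payload = torch.tensor([8, 1024, 256], dtype=torch.int64)

# New, simpler form: broadcast from global rank 0 over the default group.
dist.broadcast(payload, 0)

# Old form (for contrast), as in the removed Megatron lines:
#   dist.broadcast(payload, mpu.get_tensor_model_parallel_src_rank(),
#                  group=mpu.get_tensor_model_parallel_group())

print(f"rank {dist.get_rank()} got {payload.tolist()}")
dist.destroy_process_group()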
@@ -54,17 +54,11 @@ class MegatronGenerate(Resource):
         # Send the sizes of the tensors
         input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len]
         input_info_tensor = torch.cuda.LongTensor(input_info)
-        torch.distributed.broadcast(input_info_tensor,
-                                    mpu.get_tensor_model_parallel_src_rank(),
-                                    group=mpu.get_tensor_model_parallel_group())
-
-        # Now send tensors
-        torch.distributed.broadcast(context_length_tensor,
-                                    mpu.get_tensor_model_parallel_src_rank(),
-                                    group=mpu.get_tensor_model_parallel_group())
-        torch.distributed.broadcast(context_tokens_tensor,
-                                    mpu.get_tensor_model_parallel_src_rank(),
-                                    group=mpu.get_tensor_model_parallel_group())
+        torch.distributed.broadcast(input_info_tensor, 0)
+
+        # Send variables to all ranks
+        torch.distributed.broadcast(context_length_tensor, 0)
+        torch.distributed.broadcast(context_tokens_tensor, 0)
 
     @staticmethod
     def receive_generate_info():
@@ -72,9 +66,7 @@ class MegatronGenerate(Resource):
         Needs to be synced up with send_generate_info
         """
         input_info_tensor = torch.empty(3, dtype=torch.int64, device=torch.device("cuda"))
-        torch.distributed.broadcast(input_info_tensor,
-                                    mpu.get_tensor_model_parallel_src_rank(),
-                                    group=mpu.get_tensor_model_parallel_group())
+        torch.distributed.broadcast(input_info_tensor, 0)
         batch_size = input_info_tensor[0].item()
         seq_len = input_info_tensor[1].item()
         max_len = input_info_tensor[2].item()
@@ -82,12 +74,10 @@ class MegatronGenerate(Resource):
 
         context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.device("cuda"))
         context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.device("cuda"))
-        torch.distributed.broadcast(context_length_tensor,
-                                    mpu.get_tensor_model_parallel_src_rank(),
-                                    group=mpu.get_tensor_model_parallel_group())
-        torch.distributed.broadcast(context_tokens_tensor,
-                                    mpu.get_tensor_model_parallel_src_rank(),
-                                    group=mpu.get_tensor_model_parallel_group())
+
+        # Send variables to all ranks
+        torch.distributed.broadcast(context_length_tensor, 0)
+        torch.distributed.broadcast(context_tokens_tensor, 0)
         return context_length_tensor, context_tokens_tensor, max_len
 
     @staticmethod
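Taken together, the two sides above implement a simple two-phase handshake: rank 0 first broadcasts a small tensor of sizes so the other ranks can allocate buffers of the right shape, then broadcasts the data tensors themselves. A stripped-down sketch of that pattern (CPU tensors and the default process group here, unlike the CUDA tensors in the actual server; assumes torch.distributed is already initialized):

import torch
import torch.distributed as dist

def send_generate_info(context_tokens, context_lengths, max_len):
    # Phase 1: rank 0 broadcasts the sizes so peers can allocate buffers.
    sizes = torch.tensor([context_tokens.size(0), context_tokens.size(1), max_len],
                         dtype=torch.int64)
    dist.broadcast(sizes, 0)
    # Phase 2: broadcast the actual data.
    dist.broadcast(context_lengths, 0)
    dist.broadcast(context_tokens, 0)

def receive_generate_info():
    # Phase 1: receive the sizes from rank 0.
    sizes = torch.empty(3, dtype=torch.int64)
    dist.broadcast(sizes, 0)
    batch_size, seq_len, max_len = (x.item() for x in sizes)
    # Phase 2: allocate matching buffers and receive the data.
    context_lengths = torch.empty(batch_size, dtype=torch.int64)
    context_tokens = torch.empty(batch_size, seq_len, dtype=torch.int64)
    dist.broadcast(context_lengths, 0)
    dist.broadcast(context_tokens, 0)
    return context_lengths, context_tokens, max_len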
@@ -100,22 +90,26 @@ class MegatronGenerate(Resource):
         return decode_tokens
 
     def put(self):
+        args = get_args()
         sentences = request.get_json()["sentences"]
-        max_len = 1024  # TODO (rprenger) this should not be hardcoded
+        max_len = args.seq_length
         if "max_len" in request.get_json():
-            max_len = request.get_json()["max_len"]
+            input_max_len = request.get_json()["max_len"]
+            if input_max_len < args.seq_length:
+                max_len = input_max_len
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
 
         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
         MegatronGenerate.send_generate_info(context_tokens_tensor, context_length_tensor, max_len)  # Send them info
         decode_tokens = MegatronGenerate.do_generate(self.model, context_length_tensor, context_tokens_tensor, max_len)  # Do stuff
 
         args = get_args()
         tokenizer = get_tokenizer()
         decode_tokens, _ = decode_tokens
-        decode_tokens = decode_tokens[0].cpu().numpy().tolist()
-        trim_decode_tokens = tokenizer.detokenize(decode_tokens)
-        return jsonify({"sentences": [trim_decode_tokens]})
+        resp_sentences = []
+        for i in range(decode_tokens.size(0)):
+            decode_token = decode_tokens[i,:].cpu().numpy().tolist()
+            resp_sentences.append(tokenizer.detokenize(decode_token))
+        return jsonify({"sentences": resp_sentences})
 
 class MegatronServer(object):
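With the reworked put() handler, the endpoint accepts a batch of sentences plus an optional max_len that is capped at args.seq_length, and it returns one generated string per input sentence instead of only the first. A hypothetical client call (the host, port, and URL path are assumptions; they are defined elsewhere in the server, not in this diff):

import requests

# Hypothetical endpoint; host/port/path are not shown in this diff.
url = "http://localhost:5000/generate"

payload = {
    "sentences": ["The quick brown fox", "Megatron-LM is"],
    "max_len": 64,  # optional; values >= args.seq_length fall back to seq_length
}

resp = requests.put(url, json=payload)
resp.raise_for_status()
for s in resp.json()["sentences"]:
    print(s)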
@@ -40,7 +40,8 @@ def get_batch(context_tokens):
     tokenizer = get_tokenizer()
 
     # Move to GPU.
-    tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
+    tokens = context_tokens.contiguous().cuda()
+
     # Get the attention mask and postition ids.
     attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
         tokens,
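The get_batch change drops the reshape to args.micro_batch_size, so the batch dimension is whatever the incoming request produced. A small illustration of why the old reshape was fragile (the shapes are made up for the example; the real code also calls .cuda()):

import torch

micro_batch_size = 8                                        # stand-in for args.micro_batch_size
context_tokens = torch.zeros(3, 100, dtype=torch.int64)     # a 3-sentence request

# Old behaviour: force the request into the training micro-batch shape.
# 300 elements cannot be split into 8 equal rows, so view() raises; even when
# the sizes happen to divide, the rows would no longer match the request's sentences.
try:
    tokens = context_tokens.view(micro_batch_size, -1).contiguous()
except RuntimeError as e:
    print("old path fails:", e)

# New behaviour: keep the request's own batch size.
tokens = context_tokens.contiguous()
print("new path shape:", tuple(tokens.shape))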
@@ -464,7 +465,6 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
 
 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
                           maxlen=None, type_ids=None):
-
     args = get_args()
     tokenizer = get_tokenizer()