"vscode:/vscode.git/clone" did not exist on "4e768bf3f6dff04e6c7bc5cc9b173a796887a30f"
Commit 59414b33 authored by rprenger

Made some simplifications that keep it from hanging

parent 7a9c4a03
@@ -41,9 +41,11 @@ class MegatronGenerate(Resource):
         max_len = 64  # Choosing hopefully sane default. Full sequence is slow
         if "max_len" in request.get_json():
-            input_max_len = request.get_json()["max_len"]
-            if input_max_len < args.seq_length:
-                max_len = input_max_len
+            max_len = request.get_json()["max_len"]
+            if not isinstance(max_len, int):
+                return "max_len must be an integer greater than 0"
+            if max_len < 1:
+                return "max_len must be an integer greater than 0"
 
         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
         resp_sentences = generate(self.model, sentences, max_len)
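For context, here is a minimal sketch of how a client would exercise the validation added above. The route, port, HTTP verb, and payload keys are assumptions made for illustration and are not taken from this diff; only the max_len checks themselves are part of the commit.

```python
# Hypothetical client call; URL, verb, and payload shape are assumptions.
import requests

resp = requests.put(
    "http://localhost:5000/generate",  # assumed route and port
    json={"sentences": ["Hello, my name is"], "max_len": 32},
)
print(resp.text)

# With the checks above, a payload such as {"max_len": "32"} or {"max_len": 0}
# now gets back the string "max_len must be an integer greater than 0"
# instead of being forwarded to generate().
```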
@@ -104,21 +104,6 @@ def tokenize_batch(sentences):
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
     return context_tokens_tensor, context_length_tensor
 
-def get_token_stream(model, context_tokens_tensor, context_length_tensor):
-    context_length = context_length_tensor.min().item()
-    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
-
-    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
-                                                 context_length_tensor,
-                                                 attention_mask, position_ids)
-    for tokens, lengths in batch_token_iterator:
-        context_length += 1
-        if tokens is not None:
-            yield tokens[:, :context_length], lengths
-        else:
-            yield None, None
-
 def send_generate_info(context_tokens_tensor, context_length_tensor, max_len):
     """
     Needs to be synced up with receive_generate_info
@@ -151,13 +136,19 @@ def receive_generate_info():
     return context_length_tensor, context_tokens_tensor, max_len
 
-def synced_generate(model, context_length_tensor, context_tokens_tensor, max_len):
-    token_stream = get_token_stream(model, context_tokens_tensor, context_length_tensor)
-    for i, decode_tokens in enumerate(token_stream):
-        if i == max_len-1:
-            break
-        pass
-    return decode_tokens
+def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len):
+    context_length = context_length_tensor.min().item()
+    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
+
+    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
+                                                 context_length_tensor,
+                                                 attention_mask, position_ids,
+                                                 max_len)
+    for tokens, lengths in batch_token_iterator:
+        context_length += 1
+
+    if tokens is not None:
+        return tokens[:, :context_length]
 
 def generate(model, sentences=None, max_len=0):
     if torch.distributed.get_rank() == 0:
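The new synced_generate boils down to one pattern: drain a generator that bounds itself (sample_sequence_batch now receives max_len), then return whatever it last yielded. A toy sketch of that pattern, with made-up names rather than the repo's code:

```python
# Toy illustration only; bounded_steps() and drain() are invented names.
def bounded_steps(start, max_steps):
    # Stands in for sample_sequence_batch: it stops itself after max_steps.
    value = start
    for _ in range(max_steps):
        value += 1
        yield value

def drain(start, max_steps):
    value = None
    for value in bounded_steps(start, max_steps):
        pass  # run to exhaustion; no early break needed
    if value is not None:
        return value  # last yielded value, like tokens[:, :context_length]

print(drain(10, 3))  # 13
```

Presumably this is what "keep it from hanging" refers to: the iteration count is now fixed inside the generator by max_len rather than enforced with a break in the consumer.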
@@ -169,12 +160,11 @@ def generate(model, sentences=None, max_len=0):
     else:
         context_length_tensor, context_tokens_tensor, max_len = receive_generate_info()
 
-    decode_tokens = synced_generate(model, context_length_tensor, context_tokens_tensor, max_len)
+    decode_tokens = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len)
 
     if torch.distributed.get_rank() == 0:
         args = get_args()
         tokenizer = get_tokenizer()
-        decode_tokens, _ = decode_tokens
         resp_sentences = []
         for i in range(decode_tokens.size(0)):
             decode_token = decode_tokens[i,:].cpu().numpy().tolist()
@@ -248,9 +238,12 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         tokens = context_tokens
         if maxlen is None:
             maxlen = args.seq_length - 1
-            if maxlen > (org_context_length + args.out_seq_length):
-                maxlen = org_context_length + args.out_seq_length
+
+        maxlen = maxlen + org_context_length
+
+        if maxlen > (org_context_length + args.out_seq_length):
+            maxlen = org_context_length + args.out_seq_length
 
         lengths = torch.ones([batch_size]).long().cuda() * maxlen
 
         while context_length <= (maxlen):
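A small worked example of the clamping above, with illustrative numbers that are not taken from the commit, and assuming (as the added lines suggest) that max_len counts tokens generated beyond the prompt:

```python
# Illustrative values only.
org_context_length = 10   # prompt length in tokens
out_seq_length = 256      # stands in for args.out_seq_length
seq_length = 1024         # stands in for args.seq_length

maxlen = 64               # max_len forwarded from the request via synced_generate
if maxlen is None:
    maxlen = seq_length - 1               # no limit supplied: fall back to model length

maxlen = maxlen + org_context_length      # 64 new tokens after a 10-token prompt -> 74

if maxlen > (org_context_length + out_seq_length):
    maxlen = org_context_length + out_seq_length  # hard cap at 10 + 256 = 266

print(maxlen)  # 74; the sampling loop runs while context_length <= 74
```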