Commit 77979e3b authored by rprenger

Changing api to tokens_to_generate, making it so we always generate at least tokens_to_generate

parent 42e83ee0
@@ -12,10 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import datetime
 import torch
+import json
 from flask import Flask, request, jsonify, current_app
 from flask_restful import Resource, Api
 from megatron import get_args
 from megatron import mpu
 from megatron.text_generation_utils import generate
@@ -35,17 +36,20 @@ class MegatronGenerate(Resource):
     def put(self):
         args = get_args()
+        print("request IP: " + str(request.remote_addr))
+        print(json.dumps(request.get_json()),flush=True)
+        print("current time: ", datetime.datetime.now())
         sentences = request.get_json()["sentences"]
         if len(sentences) > 128:
             return "Maximum number of sentences is 128", 400
-        max_len = 64  # Choosing hopefully sane default.  Full sequence is slow
-        if "max_len" in request.get_json():
-            max_len = request.get_json()["max_len"]
-            if not isinstance(max_len, int):
-                return "max_len must be an integer greater than 0"
-            if max_len < 1:
-                return "max_len must be an integer greater than 0"
+        tokens_to_generate = 64  # Choosing hopefully sane default.  Full sequence is slow
+        if "tokens_to_generate" in request.get_json():
+            tokens_to_generate = request.get_json()["tokens_to_generate"]
+            if not isinstance(tokens_to_generate, int):
+                return "tokens_to_generate must be an integer greater than 0"
+            if tokens_to_generate < 1:
+                return "tokens_to_generate must be an integer greater than 0"
         all_probs = False
         if "all_probs" in request.get_json():
@@ -54,7 +58,7 @@ class MegatronGenerate(Resource):
                 return "all_probs must be a boolean value"
         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, max_len, all_probs)
+        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs)
         if all_probs:
             return jsonify({"sentences": resp_sentences,
                             "segments": resp_sentences_seg,
@@ -66,15 +70,12 @@
                             "segments": resp_sentences_seg,
                             "logits": output_logits})
-def index():
-    return current_app.send_static_file('index.html')
 class MegatronServer(object):
     def __init__(self, model):
-        self.app = Flask(__name__)
-        self.app.add_url_rule('/', 'index', index)
+        self.app = Flask(__name__, static_folder='static', static_url_path='')
+        self.app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 0
         api = Api(self.app)
         api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])
     def run(self, url):
-        self.app.run(url, threaded=False, debug=False)
+        self.app.run(url, threaded=True, debug=False)
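With static_folder='static' and static_url_path='', Flask serves files under ./static from the site root (e.g. /index.html), which is why the explicit index() route above could be dropped; SEND_FILE_MAX_AGE_DEFAULT = 0 disables client-side caching of those files. A rough startup sketch, where the model object and the bind address are assumptions rather than anything shown in this diff:

# Hypothetical usage: `model` comes from Megatron's normal model-loading path.
server = MegatronServer(model)
server.run("0.0.0.0")   # threaded=True lets the dev server answer static-file
                        # requests while a generation request is in flight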
@@ -104,12 +104,12 @@ def tokenize_batch(sentences):
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
     return context_tokens_tensor, context_length_tensor
-def send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs):
+def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
     """
     Needs to be synced up with receive_generate_info
     """
     # Send the sizes of the tensors
-    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len, all_probs]
+    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs]
     input_info_tensor = torch.cuda.LongTensor(input_info)
     torch.distributed.broadcast(input_info_tensor, 0)
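send_generate_info/receive_generate_info rely on the common torch.distributed idiom of packing scalars into one tensor and broadcasting it from rank 0. A standalone sketch of that idiom, independent of Megatron and assuming the process group and CUDA device are already initialized:

import torch
import torch.distributed as dist

def broadcast_scalars(values, src=0):
    # One broadcast moves all scalars; non-src ranks pass placeholders of the
    # same length and get them overwritten with src's values.
    t = torch.tensor(values, dtype=torch.int64, device=torch.cuda.current_device())
    dist.broadcast(t, src)
    return [x.item() for x in t]

# rank 0:      broadcast_scalars([batch_size, seq_len, tokens_to_generate, int(all_probs)])
# other ranks: broadcast_scalars([0, 0, 0, 0])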
@@ -125,7 +125,7 @@ def receive_generate_info():
     torch.distributed.broadcast(input_info_tensor, 0)
     batch_size = input_info_tensor[0].item()
     seq_len = input_info_tensor[1].item()
-    max_len = input_info_tensor[2].item()
+    tokens_to_generate = input_info_tensor[2].item()
     all_probs = input_info_tensor[3].item()
     context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
@@ -135,16 +135,16 @@ def receive_generate_info():
     torch.distributed.broadcast(context_length_tensor, 0)
     torch.distributed.broadcast(context_tokens_tensor, 0)
-    return context_length_tensor, context_tokens_tensor, max_len, all_probs
+    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs
-def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len, all_probs):
+def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
                                                  attention_mask, position_ids,
-                                                 max_len,
+                                                 tokens_to_generate,
                                                  all_probs)
     for tokens, lengths, output_logits, full_logits in batch_token_iterator:
         context_length += 1
@@ -175,15 +175,15 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len, all_probs):
     if tokens is not None:
         return tokens[:, :context_length], output_logits, full_logits
-def generate(model, sentences=None, max_len=0, all_probs=False):
+def generate(model, sentences=None, tokens_to_generate=0, all_probs=False):
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
-        send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs)
+        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
     else:
-        context_length_tensor, context_tokens_tensor, max_len, all_probs = receive_generate_info()
+        context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
-    output = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len, all_probs)
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
     if output is not None:
         decode_tokens, output_logits, full_logits = output
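Because only rank 0 sees the HTTP request, generate() is entered on every rank, and the prompt list is only meaningful on rank 0; the remaining ranks pick everything up through receive_generate_info(). A sketch of that calling convention (the prompt text and token count are illustrative):

# All ranks call generate(); only rank 0 passes real prompts.
if torch.distributed.get_rank() == 0:
    out = generate(model, sentences=["Hello, world"], tokens_to_generate=16)
else:
    out = generate(model)   # sentences=None; values arrive via the broadcasts above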
@@ -264,7 +264,7 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
-                          maxlen, all_probs=False, type_ids=None):
+                          tokens_to_generate, all_probs=False, type_ids=None):
     args = get_args()
     tokenizer = get_tokenizer()
@@ -280,7 +280,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
     eos_id = tokenizer.eod
     counter = 0
-    org_context_length = context_length
     layer_past = None
     batch_size = context_tokens.size(0)
@@ -288,8 +287,8 @@
     tokens = context_tokens
     output_logits = None
-    # TODO(rprenger) maxlen should be named a different parameter
-    maxlen = maxlen + org_context_length
+    # Generate enough tokens for the longest sequence
+    maxlen = tokens_to_generate + context_lengths.max().item()
     # TODO(rprenger) Need a better understanding of what args.seq_length vs args.out_seq_length (shouldn't be "args")
     if maxlen > args.seq_length:
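This hunk is the behavioral change named in the commit message: previously maxlen was tokens_to_generate plus org_context_length (set from the batch's minimum context length earlier in the function, outside the hunks shown), so longer prompts could receive fewer than tokens_to_generate new tokens; budgeting from the longest prompt fixes that. A worked example with illustrative numbers, not taken from the diff:

# Batch with context lengths [5, 12] and tokens_to_generate = 64:
#   before: maxlen = 64 + 5  = 69  -> the length-12 prompt gets at most 57 new tokens
#   after:  maxlen = 64 + 12 = 76  -> every prompt can get at least 64 new tokens,
#           still subject to the args.seq_length cap checked just below.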
@@ -357,7 +356,6 @@
                 if all_probs:
                     full_logits = torch.cat([full_logits, output_context], 1)
-                #output_logits = torch.cat([output_logits, output[:,context_length,new_tokens]], 1)
                 src = mpu.get_pipeline_model_parallel_last_rank()
                 group = mpu.get_embedding_group()
                 torch.distributed.broadcast(new_tokens, src, group)
...