# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flask REST server exposing Megatron text generation from rank 0."""

import datetime
import json
import threading

import torch
from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api

from megatron import get_args
from megatron.text_generation import generate_and_post_process

# Value broadcast from rank 0 to the other ranks to signal that a
# generate step should run (the other ranks block on this broadcast).
GENERATE_NUM = 0

# Serializes generation: only one HTTP worker thread may drive the model
# at a time, even though the Flask server itself is threaded.
lock = threading.Lock()


class MegatronGenerate(Resource):
    """REST resource that validates a generation request and runs the model."""

    def __init__(self, model):
        self.model = model

    @staticmethod
    def send_do_generate():
        """Broadcast the 'do generate' signal from rank 0 to all other ranks."""
        choice = torch.cuda.LongTensor([GENERATE_NUM])
        torch.distributed.broadcast(choice, 0)

    def put(self):
        """Handle a PUT generation request.

        The JSON body must contain "prompts" (list, at most 128 entries) and
        may contain "tokens_to_generate", "logprobs", "temperature", "top_k",
        "top_p" and "add_BOS".  Returns a JSON payload with the generated
        text, segments and (optionally) log probabilities, or an error
        message with HTTP status 400 on invalid input.
        """
        args = get_args()
        # Parse the request body once instead of re-parsing it per check.
        req = request.get_json()
        print("request IP: " + str(request.remote_addr))
        print(json.dumps(req), flush=True)
        print("current time: ", datetime.datetime.now())

        if not "prompts" in req:
            return "prompts argument required", 400

        if "max_len" in req:
            return "max_len is no longer used. Replace with tokens_to_generate", 400

        if "sentences" in req:
            return "sentences is no longer used. Replace with prompts", 400

        prompts = req["prompts"]
        if len(prompts) > 128:
            return "Maximum number of prompts is 128", 400

        tokens_to_generate = 64  # Choosing hopefully sane default.  Full sequence is slow
        # Fix: `just_score` had been swallowed into the comment above; it must
        # be initialized here or the checks below raise UnboundLocalError.
        just_score = False
        if "tokens_to_generate" in req:
            tokens_to_generate = req["tokens_to_generate"]
            if not isinstance(tokens_to_generate, int):
                # Fix: these validation errors previously returned HTTP 200.
                return "tokens_to_generate must be an integer greater than 0", 400
            if tokens_to_generate < 0:
                return "tokens_to_generate must be an integer greater than or equal to 0", 400
            if tokens_to_generate == 0:
                # tokens_to_generate == 0 means "score the prompt only".
                just_score = True

        logprobs = False
        if "logprobs" in req:
            logprobs = req["logprobs"]
            if not isinstance(logprobs, bool):
                return "logprobs must be a boolean value", 400

        if just_score and not logprobs:
            # Scoring-only requests would otherwise produce no output at all.
            return "tokens_to_generate=0 implies logprobs=True", 400

        temperature = 1.0
        if "temperature" in req:
            temperature = req["temperature"]
            # bool subclasses int, so exclude it explicitly to keep rejecting
            # booleans (the old `type(x) == int` check also rejected them).
            if isinstance(temperature, bool) or \
                    not isinstance(temperature, (int, float)):
                return "temperature must be a positive number less than or equal to 100.0", 400
            if not (0.0 < temperature <= 100.0):
                return "temperature must be a positive number less than or equal to 100.0", 400

        top_k = 0  # 0 disables top-k sampling (int default; was a float 0.0)
        if "top_k" in req:
            top_k = req["top_k"]
            if isinstance(top_k, bool) or not isinstance(top_k, int):
                return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000", 400
            if not (0 <= top_k <= 1000):
                return "top_k must be equal to or greater than 0 and less than or equal to 1000", 400

        top_p = 0.0  # 0.0 disables nucleus (top-p) sampling
        if "top_p" in req:
            top_p = req["top_p"]
            if not isinstance(top_p, float):
                return "top_p must be a positive float less than or equal to 1.0", 400
            if top_p > 0.0 and top_k > 0.0:
                # top-k and top-p are mutually exclusive sampling strategies.
                return "cannot set both top-k and top-p samplings.", 400
            if not (0 <= top_p <= 1.0):
                return "top_p must be less than or equal to 1.0", 400

        add_BOS = False
        if "add_BOS" in req:
            add_BOS = req["add_BOS"]
            if not isinstance(add_BOS, bool):
                return "add_BOS must be a boolean value", 400

        with lock:  # Need to get lock to keep multiple threads from hitting code
            MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
            response, response_seg, response_logprobs, _ = \
                generate_and_post_process(
                    self.model,
                    prompts=prompts,
                    tokens_to_generate=tokens_to_generate,
                    return_output_log_probs=logprobs,
                    greedy_sampling=args.greedy,
                    top_k_sampling=top_k,
                    top_p_sampling=top_p,
                    temperature=temperature,
                    add_BOS=add_BOS,
                    use_eod_token_for_early_termination=True,
                    just_score=just_score)

        return jsonify({"text": response,
                        "segments": response_seg,
                        "logprobs": response_logprobs})


class MegatronServer(object):
    """Thin wrapper wiring MegatronGenerate into a Flask application."""

    def __init__(self, model):
        self.app = Flask(__name__, static_url_path='')
        api = Api(self.app)
        # resource_class_args forwards `model` to MegatronGenerate.__init__.
        api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])

    def run(self, url):
        # threaded=True lets request parsing/validation overlap; the actual
        # generate step is serialized by the module-level `lock`.
        self.app.run(url, threaded=True, debug=False)