Commit 3fe6821a authored by Ryan Prenger's avatar Ryan Prenger Committed by Jared Casper
Browse files

Adding API server

parent 136d63cb
#!/bin/bash
# This example will start serving the 345M model.
DISTRIBUTED_ARGS="--nproc_per_node 1 \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000"
CHECKPOINT=<Path to checkpoint (e.g /345m)>
VOCAB_FILE=<Path to vocab.json (e.g. /gpt2-vocab.json)>
MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>
pip install flask-restful
python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_text_generation_server.py /
--tensor-model-parallel-size 1 /
--pipeline-model-parallel-size 1 /
--num-layers 24 /
--hidden-size 1024 /
--load ${CHECKPOINT} /
--num-attention-heads 16 /
--max-position-embeddings 1024 /
--tokenizer-type GPT2BPETokenizer /
--fp16 /
--micro-batch-size 1 /
--seq-length 1024 /
--out-seq-length 1024 /
--temperature 1.0 /
--vocab-file $VOCAB_FILE /
--merge-file $MERGE_FILE /
--top_p 0.9 /
--seed 42
#!/bin/bash
# This example will start serving the 345M model that is partitioned 8 way tensor parallel
DISTRIBUTED_ARGS="--nproc_per_node 8 \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000"
CHECKPOINT=<Path to checkpoint (e.g /345m)>
VOCAB_FILE=<Path to vocab.json (e.g. /gpt2-vocab.json)>
MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>
pip install flask-restful
python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_text_generation_server.py /
--tensor-model-parallel-size 8 /
--pipeline-model-parallel-size 1 /
--num-layers 24 /
--hidden-size 1024 /
--load ${CHECKPOINT} /
--num-attention-heads 16 /
--max-position-embeddings 1024 /
--tokenizer-type GPT2BPETokenizer /
--fp16 /
--micro-batch-size 1 /
--seq-length 1024 /
--out-seq-length 1024 /
--temperature 1.0 /
--vocab-file $VOCAB_FILE /
--merge-file $MERGE_FILE /
--top_p 0.9 /
--seed 42
...@@ -21,6 +21,7 @@ import time ...@@ -21,6 +21,7 @@ import time
import numpy as np import numpy as np
import torch import torch
from datetime import timedelta
from megatron import fused_kernels from megatron import fused_kernels
from megatron import get_adlr_autoresume from megatron import get_adlr_autoresume
...@@ -175,8 +176,8 @@ def _initialize_distributed(): ...@@ -175,8 +176,8 @@ def _initialize_distributed():
# Call the init process # Call the init process
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=args.distributed_backend, backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank) world_size=args.world_size, rank=args.rank,
timeout=timedelta(days=7))
# Set the tensor model-parallel, pipeline model-parallel, and # Set the tensor model-parallel, pipeline model-parallel, and
# data-parallel communicators. # data-parallel communicators.
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron import get_args
from megatron import mpu
from megatron.text_generation_utils import generate
GENERATE_NUM = 0
class MegatronGenerate(Resource):
def __init__(self, model):
self.model = model
@staticmethod
def send_do_generate():
choice = torch.cuda.LongTensor([GENERATE_NUM])
torch.distributed.broadcast(choice,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
def put(self):
args = get_args()
sentences = request.get_json()["sentences"]
if len(sentences) > 128:
return "Maximum number of sentences is 128", 400
max_len = 64 # Choosing hopefully sane default. Full sequence is slow
if "max_len" in request.get_json():
max_len = request.get_json()["max_len"]
if not isinstance(max_len, int):
return "max_len must be an integer greater than 0"
if max_len < 1:
return "max_len must be an integer greater than 0"
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
resp_sentences = generate(self.model, sentences, max_len)
return jsonify({"sentences": resp_sentences})
def index():
return current_app.send_static_file('index.html')
class MegatronServer(object):
def __init__(self, model):
self.app = Flask(__name__)
self.app.add_url_rule('/', 'index', index)
api = Api(self.app)
api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])
def run(self, url):
self.app.run(url, threaded=False, debug=False)
This diff is collapsed.
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sample Generate GPT"""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import socket
from megatron import get_args
from megatron import print_rank_0
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_server import MegatronServer
from megatron.text_generation_utils import generate
import torch
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
model = GPTModel(num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process)
return model
def add_text_generate_args(parser):
group = parser.add_argument_group(title='text generation')
group.add_argument("--temperature", type=float, default=1.0,
help='Sampling temperature.')
group.add_argument("--greedy", action='store_true', default=False,
help='Use greedy sampling.')
group.add_argument("--top_p", type=float, default=0.0,
help='Top p sampling.')
group.add_argument("--top_k", type=int, default=0,
help='Top k sampling.')
group.add_argument("--out-seq-length", type=int, default=1024,
help='Size of the output generated text.')
return parser
if __name__ == "__main__":
initialize_megatron(extra_args_provider=add_text_generate_args,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
'no_load_rng': True,
'no_load_optim': True})
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
# Set up model and load checkpoint
model = get_model(model_provider)
if args.load is not None:
_ = load_checkpoint(model, None, None)
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
server = MegatronServer(model)
server.run("0.0.0.0")
while True:
choice = torch.cuda.LongTensor(1)
torch.distributed.broadcast(choice,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
if choice[0].item() == 0:
generate(model)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
import urllib2
class PutRequest(urllib2.Request):
'''class to handling putting with urllib2'''
def get_method(self, *args, **kwargs):
return 'PUT'
if __name__ == "__main__":
url = sys.argv[1]
while True:
sentence = raw_input("Enter prompt: ")
max_len = int(input("Enter number tokens output: "))
data = json.dumps({"sentences": [sentence], "max_len":max_len})
req = PutRequest(url, data, {'Content-Type': 'application/json'})
response = urllib2.urlopen(req)
resp_sentences = json.load(response)
print("Megatron Response: ")
print(resp_sentences["sentences"][0])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment