Commit 67aa8619 authored by Jared Casper
Browse files

Merge branch 'main' into checkpoint_util

parents 03d09af0 f5345dfa
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Setup for pip package."""
import os
import sys
import setuptools
# Megatron requires Python 3; fail fast with a clear message on Python 2.
if sys.version_info < (3,):
    raise Exception("Python 2 is not supported by Megatron.")

from megatron.package_info import (
    __description__,
    __contact_names__,
    __url__,
    __download_url__,
    __keywords__,
    __license__,
    __package_name__,
    __version__,
)

# Use the README as the PyPI long description. Read it with an explicit
# encoding so installation does not depend on the platform default
# (e.g. cp1252 on Windows) if the README contains non-ASCII characters.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()
###############################################################################
# Dependency Loading #
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
def req_file(filename):
    """Parse a pip requirements file into a list of requirement strings.

    Args:
        filename: Path to a requirements-style text file, one entry per line.

    Returns:
        List of stripped requirement strings. Blank lines and comment lines
        (starting with '#') are skipped so empty/bogus entries are not
        handed to setuptools via ``install_requires``.
    """
    with open(filename, encoding="utf-8") as f:
        lines = f.readlines()
    stripped = (line.strip() for line in lines)
    return [line for line in stripped if line and not line.startswith('#')]
# Runtime dependencies are single-sourced from requirements.txt.
install_requires = req_file("requirements.txt")

# Package metadata and build configuration. All identity fields come from
# megatron/package_info.py so they stay in sync with the code.
setuptools.setup(
    name=__package_name__,
    # Versions should comply with PEP440. For a discussion on single-sourcing
    # the version across setup.py and the project code, see
    # https://packaging.python.org/en/latest/single_source_version.html
    version=__version__,
    description=__description__,
    long_description=long_description,
    # Render the README as Markdown on PyPI.
    long_description_content_type="text/markdown",
    # The project's main homepage.
    url=__url__,
    author=__contact_names__,
    maintainer=__contact_names__,
    # The licence under which the project is released
    license=__license__,
    classifiers=[
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'Intended Audience :: Information Technology',
        # Indicate what your project relates to
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Software Development :: Libraries :: Python Modules',
        # Supported python versions
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        # Additional Setting
        'Environment :: Console',
        'Natural Language :: English',
        'Operating System :: OS Independent',
    ],
    python_requires='>=3.6',
    # Discover all megatron.* sub-packages automatically.
    packages=setuptools.find_packages(),
    install_requires=install_requires,
    # Add in any packaged data.
    include_package_data=True,
    zip_safe=False,
    # PyPI package information.
    keywords=__keywords__
)
...@@ -205,7 +205,7 @@ def main(): ...@@ -205,7 +205,7 @@ def main():
args.task)) args.task))
# Set up model and load checkpoint. # Set up model and load checkpoint.
model = get_model(get_model_provider(eval_metric)) model = get_model(get_model_provider(eval_metric), wrap_with_ddp=False)
if args.load is not None: if args.load is not None:
_ = load_checkpoint(model, None, None) _ = load_checkpoint(model, None, None)
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sample Generate GPT"""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_utils import generate_and_write_samples_unconditional
from megatron.text_generation_utils import generate_samples_input_from_file
from megatron.text_generation_utils import generate_samples_interactive
def model_provider(pre_process=True, post_process=True):
    """Instantiate and return the GPT model used for text generation."""
    print_rank_0('building GPT model ...')
    gpt_model = GPTModel(
        num_tokentypes=0,
        parallel_output=False,
        pre_process=pre_process,
        post_process=post_process,
    )
    return gpt_model
def add_text_generate_args(parser):
    """Text generation arguments."""
    group = parser.add_argument_group(title='text generation')
    # Bind once; every option below goes into the same group.
    add = group.add_argument

    # Sampling controls.
    add("--temperature", type=float, default=1.0,
        help='Sampling temperature.')
    add("--greedy", action='store_true', default=False,
        help='Use greedy sampling.')
    add("--top_p", type=float, default=0.0,
        help='Top p sampling.')
    add("--top_k", type=int, default=0,
        help='Top k sampling.')
    add("--out-seq-length", type=int, default=1024,
        help='Size of the output generated text.')

    # Input/output plumbing.
    add("--sample-input-file", type=str, default=None,
        help='Get input from file instead of interactive mode, '
             'each line is an input.')
    add("--sample-output-file", type=str, default=None,
        help='Output file got from --sample-input-file')
    add("--num-samples", type=int, default=0,
        help='Number of samples to generate unconditionally, '
             'defaults to 0 and interactive conditional sampling')
    add("--genfile", type=str,
        help='Output file when generating unconditionally')
    add("--recompute", action='store_true',
        help='During generation recompute all attention '
             'instead of using previously computed keys/values.')
    return parser
def main():
    """Set up Megatron, load a GPT checkpoint, and generate samples.

    Depending on the command-line arguments, generation is either
    unconditional (--num-samples > 0), driven by prompts from a file
    (--sample-input-file), or interactive.
    """
    # Default to GPT2 BPE tokenization; skip loading optimizer and RNG
    # state since generation only needs the model weights.
    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
                                       'no_load_rng': True,
                                       'no_load_optim': True})

    args = get_args()

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        # Use sys.exit() rather than the interactive-only exit() builtin.
        sys.exit()

    # Set up model and load checkpoint.
    model = get_model(model_provider)

    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    # get_model returns a list of model chunks; without an interleaved
    # schedule (rejected above) there is exactly one chunk.
    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    # Generate samples.
    if args.num_samples == 0:
        # Conditional sampling processes one prompt at a time.
        args.micro_batch_size = 1
        # Compare against None with identity, not equality.
        if args.sample_input_file is not None:
            generate_samples_input_from_file(model)
        else:
            generate_samples_interactive(model)
    else:
        generate_and_write_samples_unconditional(model)


if __name__ == "__main__":
    main()
...@@ -27,7 +27,7 @@ from megatron.initialize import initialize_megatron ...@@ -27,7 +27,7 @@ from megatron.initialize import initialize_megatron
from megatron.model import GPTModel from megatron.model import GPTModel
from megatron.training import get_model from megatron.training import get_model
from megatron.text_generation_server import MegatronServer from megatron.text_generation_server import MegatronServer
from megatron.text_generation_utils import generate from megatron.text_generation import generate_and_post_process
import torch import torch
def model_provider(pre_process=True, post_process=True): def model_provider(pre_process=True, post_process=True):
...@@ -43,8 +43,6 @@ def add_text_generate_args(parser): ...@@ -43,8 +43,6 @@ def add_text_generate_args(parser):
group.add_argument("--temperature", type=float, default=1.0, group.add_argument("--temperature", type=float, default=1.0,
help='Sampling temperature.') help='Sampling temperature.')
group.add_argument("--greedy", action='store_true', default=False,
help='Use greedy sampling.')
group.add_argument("--top_p", type=float, default=0.0, group.add_argument("--top_p", type=float, default=0.0,
help='Top p sampling.') help='Top p sampling.')
group.add_argument("--top_k", type=int, default=0, group.add_argument("--top_k", type=int, default=0,
...@@ -65,7 +63,7 @@ if __name__ == "__main__": ...@@ -65,7 +63,7 @@ if __name__ == "__main__":
print("Interleaved pipeline schedule is not yet supported for text generation.") print("Interleaved pipeline schedule is not yet supported for text generation.")
exit() exit()
# Set up model and load checkpoint # Set up model and load checkpoint
model = get_model(model_provider) model = get_model(model_provider, wrap_with_ddp=False)
if args.load is not None: if args.load is not None:
_ = load_checkpoint(model, None, None) _ = load_checkpoint(model, None, None)
...@@ -78,8 +76,6 @@ if __name__ == "__main__": ...@@ -78,8 +76,6 @@ if __name__ == "__main__":
while True: while True:
choice = torch.cuda.LongTensor(1) choice = torch.cuda.LongTensor(1)
torch.distributed.broadcast(choice, torch.distributed.broadcast(choice, 0)
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
if choice[0].item() == 0: if choice[0].item() == 0:
generate(model) generate_and_post_process(model)
...@@ -25,10 +25,10 @@ if __name__ == "__main__": ...@@ -25,10 +25,10 @@ if __name__ == "__main__":
url = sys.argv[1] url = sys.argv[1]
while True: while True:
sentence = raw_input("Enter prompt: ") sentence = raw_input("Enter prompt: ")
max_len = int(input("Enter number tokens output: ")) tokens_to_generate = int(input("Enter number of tokens to generate: "))
data = json.dumps({"sentences": [sentence], "max_len":max_len}) data = json.dumps({"prompts": [sentence], "tokens_to_generate":tokens_to_generate})
req = PutRequest(url, data, {'Content-Type': 'application/json'}) req = PutRequest(url, data, {'Content-Type': 'application/json'})
response = urllib2.urlopen(req) response = urllib2.urlopen(req)
resp_sentences = json.load(response) resp_sentences = json.load(response)
print("Megatron Response: ") print("Megatron Response: ")
print(resp_sentences["sentences"][0]) print(resp_sentences["text"][0])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment