Commit 41276b6c authored by Vijay Korthikanti
Browse files

Merge branch 'main' into nmt-main

parents a44360ed fc7f4f03
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality.""" """Main tasks functionality."""
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model evaluation""" """Model evaluation"""
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run multi-stage dialogue prompting (MSDP).""" """Run multi-stage dialogue prompting (MSDP)."""
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets""" """Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prompting the pretrained language model to generate knowledge/response""" """Prompting the pretrained language model to generate knowledge/response"""
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality.""" """Main tasks functionality."""
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch import torch
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ORQA dataset.""" """ORQA dataset."""
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation utilities.""" """Evaluation utilities."""
from collections import OrderedDict from collections import OrderedDict
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ORQA finetuning/evaluation.""" """ORQA finetuning/evaluation."""
...@@ -63,7 +50,7 @@ def orqa(Dataset): ...@@ -63,7 +50,7 @@ def orqa(Dataset):
tokenizer = get_tokenizer() tokenizer = get_tokenizer()
# Get the batch. # Get the batch.
timers('batch generator').start() timers('batch generator', log_level=2).start()
try: try:
batch_ = next(batch) batch_ = next(batch)
except BaseException: except BaseException:
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Data Loader for Google NQ dataset Data Loader for Google NQ dataset
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Race.""" """Race."""
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision-classification finetuning/evaluation.""" """Vision-classification finetuning/evaluation."""
...@@ -68,7 +55,7 @@ def classification(): ...@@ -68,7 +55,7 @@ def classification():
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers("batch generator").start() timers("batch generator", log_level=2).start()
try: try:
batch_ = next(batch) batch_ = next(batch)
except BaseException: except BaseException:
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation utilities.""" """Evaluation utilities."""
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetune utilities.""" """Finetune utilities."""
...@@ -136,7 +123,7 @@ def _train( ...@@ -136,7 +123,7 @@ def _train(
report_memory_flag = True report_memory_flag = True
# For each remaining epoch # For each remaining epoch
timers("interval-time").start() timers("interval-time", log_level=0).start(barrier=True)
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):
print_rank_0("working on epoch {} ...".format(epoch + 1)) print_rank_0("working on epoch {} ...".format(epoch + 1))
...@@ -218,7 +205,7 @@ def finetune( ...@@ -218,7 +205,7 @@ def finetune(
timers = get_timers() timers = get_timers()
# Train and validation data loaders. # Train and validation data loaders.
timers("train/valid/test dataset/dataloder").start() timers("train/valid/test dataset/dataloder", log_level=0).start()
if args.epochs > 0: if args.epochs > 0:
train_dataset, valid_dataset = train_valid_datasets_provider() train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders( train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
...@@ -227,14 +214,14 @@ def finetune( ...@@ -227,14 +214,14 @@ def finetune(
timers("train/valid/test dataset/dataloder").stop() timers("train/valid/test dataset/dataloder").stop()
# Build callback function. # Build callback function.
timers("callback function").start() timers("callback function", log_level=0).start()
end_of_epoch_callback = None end_of_epoch_callback = None
if end_of_epoch_callback_provider is not None: if end_of_epoch_callback_provider is not None:
end_of_epoch_callback = end_of_epoch_callback_provider() end_of_epoch_callback = end_of_epoch_callback_provider()
timers("callback function").stop() timers("callback function").stop()
# Build model, optimizer and learning rate scheduler. # Build model, optimizer and learning rate scheduler.
timers("model and optimizer").start() timers("model and optimizer", log_level=0).start()
model, optimizer, opt_param_scheduler = \ model, optimizer, opt_param_scheduler = \
setup_model_and_optimizer( setup_model_and_optimizer(
model_provider, model_provider,
...@@ -246,7 +233,7 @@ def finetune( ...@@ -246,7 +233,7 @@ def finetune(
# If pretrained checkpoint is provided and we have not trained for # If pretrained checkpoint is provided and we have not trained for
# any iteration (i.e., iteration is zero), then load the pretrained # any iteration (i.e., iteration is zero), then load the pretrained
# checkpoint. # checkpoint.
timers("pretrained checkpoint").start() timers("pretrained checkpoint", log_level=0).start(barrier=True)
if args.iteration == 0 and args.pretrained_checkpoint is not None: if args.iteration == 0 and args.pretrained_checkpoint is not None:
if args.pretrained_checkpoint_type == 'default': if args.pretrained_checkpoint_type == 'default':
original_load = args.load original_load = args.load
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality.""" """Main tasks functionality."""
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision-classification finetuning/evaluation.""" """Vision-classification finetuning/evaluation."""
...@@ -123,7 +110,7 @@ def segmentation(): ...@@ -123,7 +110,7 @@ def segmentation():
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers("batch generator").start() timers("batch generator", log_level=2).start()
import types import types
if isinstance(batch, types.GeneratorType): if isinstance(batch, types.GeneratorType):
batch_ = next(batch) batch_ = next(batch)
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision-classification finetuning/evaluation.""" """Vision-classification finetuning/evaluation."""
...@@ -86,7 +73,7 @@ def segmentation(): ...@@ -86,7 +73,7 @@ def segmentation():
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers("batch generator").start() timers("batch generator", log_level=2).start()
import types import types
if isinstance(batch, types.GeneratorType): if isinstance(batch, types.GeneratorType):
batch_ = next(batch) batch_ = next(batch)
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math import math
import einops import einops
import torch import torch
......
# coding=utf-8 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math import math
import einops import einops
import torch import torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment