Commit 06fc51ce authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main' into checkpoint_util

parents ec561daa 0ed2f6ac
......@@ -29,12 +29,13 @@ The following applies to all files unless otherwise noted:
--
This repository also contains code from Hugging Face Inc., Google Research,
Facebook (from their Fairseq project), and Philip Popien. Files from these
organizations have notices at the top of each file. Below are licenses
used in those files, as indicated.
Facebook (from their Fairseq and Dino projects), Microsoft(from their
Swin-Transformer project)and Philip Popien. Files from these
organizations have notices at the top of each file. Below are
licenses used in those files, as indicated.
------------- LICENSE FOR huggingface and Google Research code --------------
------------- LICENSE FOR Facebook, huggingface and Google Research code --------------
Apache License
......@@ -263,3 +264,113 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
------------- LICENSE FOR Mircrosoft Swin transformer code --------------
MIT License
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
--------------- NVIDIA Source Code License for SegFormer -----------------
1. Definitions
“Licensor” means any person or entity that distributes its Work.
“Software” means the original work of authorship made available under this
License.
“Work” means the Software and any additions to or derivative works of the
Software that are made available under this License.
The terms “reproduce,” “reproduction,” “derivative works,” and
“distribution” have the meaning as provided under U.S. copyright law;
provided, however, that for the purposes of this License, derivative works
shall not include works that remain separable from, or merely link
(or bind by name) to the interfaces of, the Work.
Works, including the Software, are “made available” under this License by
including in or with the Work either (a) a copyright notice referencing
the applicability of this License to the Work, or (b) a copy of this License.
2. License Grant
2.1 Copyright Grant. Subject to the terms and conditions of this License,
each Licensor grants to you a perpetual, worldwide, non-exclusive,
royalty-free, copyright license to reproduce, prepare derivative works of,
publicly display, publicly perform, sublicense and distribute its Work
and any resulting derivative works in any form.
3. Limitations
3.1 Redistribution. You may reproduce or distribute the Work only if
(a) you do so under this License, (b) you include a complete copy of this
License with your distribution, and (c) you retain without modification any
copyright, patent, trademark, or attribution notices that are present
in the Work.
3.2 Derivative Works. You may specify that additional or different terms
apply to the use, reproduction, and distribution of your derivative works
of the Work (“Your Terms”) only if (a) Your Terms provide that the use
limitation in Section 3.3 applies to your derivative works, and (b) you
identify the specific derivative works that are subject to Your Terms.
Notwithstanding Your Terms, this License (including the redistribution
requirements in Section 3.1) will continue to apply to the Work itself.
3.3 Use Limitation. The Work and any derivative works thereof only may
be used or intended for use non-commercially. Notwithstanding the
foregoing, NVIDIA and its affiliates may use the Work and any derivative
works commercially. As used herein, “non-commercially” means for research
or evaluation purposes only.
3.4 Patent Claims. If you bring or threaten to bring a patent claim against
any Licensor (including any claim, cross-claim or counterclaim in a lawsuit)
to enforce any patents that you allege are infringed by any Work, then
your rights under this License from such Licensor (including the grant
in Section 2.1) will terminate immediately.
3.5 Trademarks. This License does not grant any rights to use any Licensor’s
or its affiliates’ names, logos, or trademarks, except as necessary to
reproduce the notices described in this License.
3.6 Termination. If you violate any term of this License, then your rights
under this License (including the grant in Section 2.1) will terminate
immediately.
4. Disclaimer of Warranty.
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT.
YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
5. Limitation of Liability.
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT
OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
Megatron ([1](https://arxiv.org/pdf/1909.08053.pdf) and [2](https://arxiv.org/pdf/2104.04473.pdf)) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel (tensor and pipeline), and multi-node pre-training oftransformer based models such as [GPT](https://arxiv.org/abs/2005.14165), [BERT](https://arxiv.org/pdf/1810.04805.pdf), and [T5](https://arxiv.org/abs/1910.10683) using mixed precision.
Megatron ([1](https://arxiv.org/pdf/1909.08053.pdf) and [2](https://arxiv.org/pdf/2104.04473.pdf)) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel (tensor and pipeline), and multi-node pre-training of transformer based models such as [GPT](https://arxiv.org/abs/2005.14165), [BERT](https://arxiv.org/pdf/1810.04805.pdf), and [T5](https://arxiv.org/abs/1910.10683) using mixed precision.
Below are some of the projects where we have directly used Megatron:
* [BERT and GPT Studies Using Megatron](https://arxiv.org/pdf/1909.08053.pdf)
......@@ -11,7 +11,11 @@ Below are some of the projects where we have directly used Megatron:
* [Scaling Language Model Training to a Trillion Parameters Using Megatron](https://arxiv.org/pdf/2104.04473.pdf)
* [Training Question Answering Models From Synthetic Data](https://www.aclweb.org/anthology/2020.emnlp-main.468.pdf)
Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specifc model size. As the model size increases, we also modestly increase the batch size. We leverage [NVIDIA's Selene supercomputer](https://www.top500.org/system/179842/) to perform scaling studies and use up to 3072 [A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for the largest model. The table below shows the model configurations along with the achieved FLOPs (both per GPU and aggregate over all GPUs). Note that these results are from benchmark runs and these models were not trained to convergence; however, the FLOPs are measured for end-to-end training, i.e., includes all operations including data loading, optimization, and even logging.
Megatron is also used in [NeMo Megatron](https://developer.nvidia.com/nvidia-nemo#nemo-megatron), a framework to help enterprises overcome the challenges of building and training sophisticated natural language processing models with billions and trillions of parameters.
Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specifc model size. As the model size increases, we also modestly increase the batch size. We leverage [NVIDIA's Selene supercomputer](https://www.top500.org/system/179842/) to perform scaling studies and use up to 3072 [A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for the largest model. Each cluster node has 8 NVIDIA 80GB A100 GPUs. The table below shows the model configurations along with the achieved FLOPs (both per GPU and aggregate over all GPUs). Note that these results are from benchmark runs and these models were not trained to convergence; however, the FLOPs are measured for end-to-end training, i.e., includes all operations including data loading, optimization, and even logging.
Additionally, the model parallel size column reports a combined tensor and pipeline parallelism degrees. For numbers larger than 8, typically tensor parallel of size 8 was used. So, for example, the 145B model reports the total model parallel size of 64, which means that this setup used TP=8 and PP=8.
![Cases](images/cases_april2021.png)
......@@ -27,7 +31,6 @@ All the cases from 1 billion to 1 trillion parameters achieve more than 43% half
* [Data Preprocessing](#data-preprocessing)
* [BERT Pretraining](#bert-pretraining)
* [GPT Pretraining](#gpt-pretraining)
* [GPT Pretraining](#gpt-pretraining)
* [T5 Pretraining](#t5-pretraining)
* [Distributed Pretraining](#distributed-pretraining)
* [GPT-3 Example](#gpt-3-example)
......@@ -44,9 +47,7 @@ All the cases from 1 billion to 1 trillion parameters achieve more than 43% half
* [Collecting GPT Webtext Data](#collecting-gpt-webtext-data)
# Setup
We have tested Megatron with [NGC's PyTorch container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) version 20.12, which uses python 3.8, pytorch 1.8, cuda 11.1, and nccl 2.8.3.
To use this repository, please install the latest supported versions of PyTorch with GPU support (python 3.8, pytorch 1.8, cuda 11.1, and nccl 2.8.3 and above) and NVIDIA [APEX](https://github.com/NVIDIA/apex#quick-start). We strongly recommend using one of [NGC's recent PyTorch containers](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) (the latest compatible version at time of publication can be pulled with `docker pull nvcr.io/nvidia/pytorch:20.12-py3`). Data preprocessing requires [NLTK](https://www.nltk.org/install.html), though this is not required for training, evaluation, or downstream tasks.
We strongly recommend using the latest release of [NGC's PyTorch container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch). If you can't use this for some reason, use the latest pytorch, cuda, nccl, and NVIDIA [APEX](https://github.com/NVIDIA/apex#quick-start) releases. Data preprocessing requires [NLTK](https://www.nltk.org/install.html), though this is not required for training, evaluation, or downstream tasks.
## Downloading Checkpoints
We have provided pretrained [BERT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_bert_345m) and [GPT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_lm_345m) checkpoints for use to evaluate or finetuning downstream tasks. To access these checkpoints, first [sign up](https://ngc.nvidia.com/signup) for and [setup](https://ngc.nvidia.com/setup/installers/cli) the NVIDIA GPU Cloud (NGC) Registry CLI. Further documentation for downloading models can be found in the [NGC documentation](https://docs.nvidia.com/dgx/ngc-registry-cli-user-guide/index.html#topic_6_4_1).
......@@ -204,7 +205,7 @@ Further command line arguments are described in the source file [`arguments.py`]
## T5 Pretraining
Very similar to BERT and GPT, the `examples/pretrain_t5.sh` script runs single GPU "base" (~220M parameter) T5 pretraining. The primary difference from BERT and GPT is the addition of the following arguments to accomodate the T5 architecture:
Very similar to BERT and GPT, the `examples/pretrain_t5.sh` script runs single GPU "base" (~220M parameter) T5 pretraining. The primary difference from BERT and GPT is the addition of the following arguments to accommodate the T5 architecture:
* `--kv-channels` sets the inner dimension of the "key" and "value" matrices of all attention mechanisms in the model. For BERT and GPT this defaults to the hidden size divided by the number of attention heads, but can be configured for T5.
......@@ -260,7 +261,7 @@ Second, we developed a simple and efficient two-dimensional model-parallel appro
<!-- The number of microbatches in a per-pipeline minibatch is controlled by the `--num-microbatches-in-minibatch` argument. With `WORLD_SIZE` GPUs, `TENSOR_MP_SIZE` tensor-model-parallel size, `PIPELINE_MP_SIZE` pipeline-model-parallel-size, `WORLD_SIZE`/(`TENSOR_MP_SIZE` * `PIPELINE_MP_SIZE`) GPUs will be used for data parallelism. The default values for `--tensor-model-parallel-size` and `--pipeline-model-parallel-size` is 1, which will not implement either form of model parallelism. -->
We have examples of how to use these two different forms of model parallelism the example scripts ending in `distributed_with_mp.sh`, note that pipeline parallelism is not currently supported in the T5 model:
We have examples of how to use these two different forms of model parallelism the example scripts ending in `distributed_with_mp.sh`:
Other than these minor changes, the distributed training is identical to the training on a single GPU.
......@@ -468,7 +469,7 @@ python tasks/main.py \
### LAMBADA Cloze Accuracy
To compute LAMBADA cloze accuracy (the accuracy of predicting the last token given the preceeding tokens) we utilize a detokenized, processed version of the [LAMBADA dataset](https://github.com/cybertronai/bflm/blob/master/lambada_test.jsonl).
To compute LAMBADA cloze accuracy (the accuracy of predicting the last token given the preceding tokens) we utilize a detokenized, processed version of the [LAMBADA dataset](https://github.com/cybertronai/bflm/blob/master/lambada_test.jsonl).
We use the following command to run LAMBADA evaluation on a 345M parameter model. Note that the `--strict-lambada` flag should be used to require whole word matching. Make that `lambada` is part of the file path.
......
# Multi-Stage Prompting for Knowledgeable Dialogue Generation
This directory contains all the scripts of multi-stage prompting for knowledgeable dialogue generation that includes data preparation, and knowledge and response generations. More details are available on [`knowledgeable task directory`](../../tasks/msdp).
#!/bin/bash
# Data preparation for our framework: preprocessing the WoW and WoI datasets
# The datasets can be downloaded through the following links:
# WoW: https://parl.ai/projects/wizard_of_wikipedia/
# WoI: https://parl.ai/projects/sea/
DIR=`pwd`
# Before running the preprocessing, please download
# the wizard of wikipedia and wizard datasets
WOW_DATA_FOLDER=<PATH_OF_WIZARD_OF_WIKIPEDIA_DATA_FOLDER>
WOI_DATA_FOLDER=<PATH_OF_WIZARD_OF_INTERNET_DATA_FOLDER>
# We provide examples for processing the raw data from Wizard of Wikipedia
# Processing the train dataset (train.json)
python ${DIR}/tasks/msdp/preprocessing.py \
--func process_wow_dataset \
--raw_file ${WOW_DATA_FOLDER}/train.json \
--processed_file ${WOW_DATA_FOLDER}/train_processed.txt
# Processing test seen dataset (test_random_split.json)
python ${DIR}/tasks/msdp/preprocessing.py \
--func process_wow_dataset \
--raw_file ${WOW_DATA_FOLDER}/test_random_split.json \
--processed_file ${WOW_DATA_FOLDER}/testseen_processed.txt \
--knwl_ref_file ${WOW_DATA_FOLDER}/output_testseen_knowledge_reference.txt \
--resp_ref_file ${WOW_DATA_FOLDER}/output_testseen_response_reference.txt
# processing test unseen dataset (test_topic_split.json)
python ${DIR}/tasks/msdp/preprocessing.py \
--func process_wow_dataset \
--raw_file ${WOW_DATA_FOLDER}/test_topic_split.json \
--processed_file ${WOW_DATA_FOLDER}/testunseen_processed.txt \
--knwl_ref_file ${WOW_DATA_FOLDER}/output_testunseen_knowledge_reference.txt \
--resp_ref_file ${WOW_DATA_FOLDER}/output_testunseen_response_reference.txt
# We provide the following script to process the raw data from Wizard of Internet
# Processing the test dataset (test.jsonl)
python ${DIR}/tasks/msdp/preprocessing.py \
--func process_woi_dataset \
--raw_file ${WOI_DATA_FOLDER}/test.jsonl \
--processed_file ${WOI_DATA_FOLDER}/test_processed.txt \
--knwl_ref_file ${WOI_DATA_FOLDER}/output_test_knowledge_reference.txt \
--resp_ref_file ${WOI_DATA_FOLDER}/output_test_response_reference.txt
# Get the knowledge generation prompts for the each test dataset in WoW and WoI
MODEL_FILE=<PATH_OF_THE_FINETUNED_DPR_MODEL>
# WoW test seen
python ${DIR}/tasks/msdp/preprocessing.py \
--func get_knwl_gen_prompts \
--test_file ${WOW_DATA_FOLDER}/testseen_processed.txt \
--train_file ${WOW_DATA_FOLDER}/train_processed.txt \
--model_file ${MODEL_FILE} \
--processed_file ${WOW_DATA_FOLDER}/output_testseen_knowledge_prompts.json \
--data_type wow_seen
# WoW test unseen
python ${DIR}/tasks/msdp/preprocessing.py \
--func get_knwl_gen_prompts \
--test_file ${WOW_DATA_FOLDER}/testunseen_processed.txt \
--train_file ${WOW_DATA_FOLDER}/train_processed.txt \
--model_file ${MODEL_FILE} \
--processed_file ${WOW_DATA_FOLDER}/output_testunseen_knowledge_prompts.json \
--data_type wow_unseen
# WoI
python ${DIR}/tasks/msdp/preprocessing.py \
--func get_knwl_gen_prompts \
--test_file ${WOI_DATA_FOLDER}/test_processed.txt \
--train_file ${WOW_DATA_FOLDER}/train_processed.txt \
--model_file ${MODEL_FILE} \
--processed_file ${WOI_DATA_FOLDER}/output_test_knowledge_prompts.json \
--data_type woi
# Get the response generation prompts (can be applied for all the test datasets)
python ${DIR}/tasks/msdp/preprocessing.py \
--func get_resp_gen_prompts \
--train_file ${WOW_DATA_FOLDER}/train_processed.txt \
--processed_file ${WOW_DATA_FOLDER}/output_response_prompts.txt
#!/bin/bash
#########################
# Evaluate the F1 scores.
#########################
WORLD_SIZE=1
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000"
MODEL_GEN_PATH=<PATH_OF_THE_KNOWLEDGE_GENERATION> \
(e.g., /testseen_knowledge_generations.txt)
GROUND_TRUTH_PATH=<PATH_OF_THE_GROUND_TRUTH_KNOWLEDGE> \
(e.g., /testseen_knowledge_reference.txt)
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/msdp/main.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 4 \
--task MSDP-EVAL-F1 \
--guess-file ${MODEL_GEN_PATH} \
--answer-file ${GROUND_TRUTH_PATH}
############################################
# Evaluate BLEU, METEOR, and ROUGE-L scores.
############################################
# We follow the nlg-eval (https://github.com/Maluuba/nlg-eval) to
# evaluate the BLEU, METEOR, and ROUGE-L scores.
# To evaluate on these metrics, please setup the environments based on
# the nlg-eval github, and run the corresponding evaluation commands.
nlg-eval \
--hypothesis=<PATH_OF_THE_KNOWLEDGE_GENERATION> \
--references=<PATH_OF_THE_GROUND_TRUTH_KNOWLEDGE>
#!/bin/bash
#########################
# Evaluate the F1 scores.
#########################
WORLD_SIZE=1
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000"
MODEL_GEN_PATH=<PATH_OF_THE_RESPONSE_GENERATION> \
(e.g., /testseen_response_generations.txt)
GROUND_TRUTH_PATH=<PATH_OF_THE_GROUND_TRUTH_RESPONSE> \
(e.g., /testseen_response_reference.txt)
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/msdp/main.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 4 \
--task MSDP-EVAL-F1 \
--guess-file ${MODEL_GEN_PATH} \
--answer-file ${GROUND_TRUTH_PATH}
##########################
# Evaluate the KF1 scores.
##########################
MODEL_GEN_PATH=<PATH_OF_THE_RESPONSE_GENERATION> \
(e.g., /testseen_response_generations.txt)
GROUND_TRUTH_PATH=<PATH_OF_THE_GROUND_TRUTH_KNOWLEDGE> \
(e.g., /testseen_knowledge_reference.txt)
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/msdp/main.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 4 \
--task MSDP-EVAL-F1 \
--guess-file ${MODEL_GEN_PATH} \
--answer-file ${GROUND_TRUTH_PATH}
############################################
# Evaluate BLEU, METEOR, and ROUGE-L scores.
############################################
# We follow the nlg-eval (https://github.com/Maluuba/nlg-eval) to
# evaluate the BLEU, METEOR, and ROUGE-L scores.
# To evaluate on these metrics, please setup the environments based on
# the nlg-eval github, and run the corresponding evaluation commands.
nlg-eval \
--hypothesis=<PATH_OF_THE_RESPONSE_GENERATION> \
--references=<PATH_OF_THE_GROUND_TRUTH_RESPONSE>
#!/bin/bash
# Preparing the input file for the response generation (second-stage prompting)
DIR=`pwd`
TEST_FILE=<PATH_OF_PROCESSED_TEST_DATA> \
(e.g., /testseen_processed.txt)
KNOWLEDGE_FILE=<PATH_OF_GENERATED_KNOWLEDGE_DATA> \
(e.g., /testseen_knowledge_generations.txt)
PROCESSED_FILE=<PATH_OF_INPUT_FILE_FOR_RESPONSE_GENERATION> \
(e.g., /testseen_processed_with_generated_knowledge.txt)
python ${DIR}/tasks/msdp/preprocessing.py \
--func prepare_input \
--test_file ${TEST_FILE} \
--knwl_gen_file ${KNOWLEDGE_FILE} \
--processed_file ${PROCESSED_FILE}
#!/bin/bash
# Stage-1: Prompt a pretrained language model to generate the context-relevant knowledge
# The input contains prompts and current dialogue context, the output is the relevant knowledge
# The size of the pretrained language model is 357M
WORLD_SIZE=8
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000"
CHECKPOINT_PATH=<PATH_OF_LANGUAGE_MODEL> (e.g., /357m)
VOCAB_PATH=<PATH_OF_VOCAB_FILE> (e.g., /gpt2-vocab.json)
MERGE_PATH=<PATH_OF_MERGE_FILE> (e.g., /gpt2-merges.txt)
INPUT_PATH=<PATH_OF_PROCESSED_TEST_DATA_FILE> \
(e.g., /testseen_processed.txt)
PROMPT_PATH=<PATH_OF_KNOWLEDGE_GENERATION_PROMPTS> \
(e.g., /testseen_knowledge_prompts.json)
OUTPUT_PATH=<PATH_OF_OUTPUT_GENERATION_FILE> \
(e.g., /testseen_knowledge_generations.txt)
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/msdp/main.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 1 \
--vocab-file ${VOCAB_PATH} \
--merge-file ${MERGE_PATH} \
--load ${CHECKPOINT_PATH} \
--fp16 \
--DDP-impl torch \
--tokenizer-type GPT2BPETokenizer \
--sample-input-file ${INPUT_PATH} \
--sample-output-file ${OUTPUT_PATH} \
--prompt-file ${PROMPT_PATH} \
--prompt-type knowledge \
--num-prompt-examples 10 \
--task MSDP-PROMPT
# NOTE: If you use api for the model generation, please use
# the "--api-prompt" flag (setting this value as True).
#!/bin/bash
# Stage-2: Prompt a pretrained language model to generate the corresponding response
# The input contains prompts, current dialogue context, and generated knowledge in Stage-1
# The output is the corresponding response.
# The size of the pretrained language model is 357M
WORLD_SIZE=8
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000"
CHECKPOINT_PATH=<PATH_OF_LANGUAGE_MODEL> (e.g., /357m)
VOCAB_PATH=<PATH_OF_VOCAB_FILE> (e.g., /gpt2-vocab.json)
MERGE_PATH=<PATH_OF_MERGE_FILE> (e.g., /gpt2-merges.txt)
INPUT_PATH=<PATH_OF_INPUT_TEST_DATA_FILE> (e.g., /testseen_processed.txt)
PROMPT_PATH=<PATH_OF_RESPONSE_GENERATION_PROMPTS> \
(e.g., /response_prompts.txt)
OUTPUT_PATH=<PATH_OF_OUTPUT_GENERATION_FILE> \
(e.g., /output_testseen_response_generations.txt)
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/msdp/main.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 1 \
--vocab-file ${VOCAB_PATH} \
--merge-file ${MERGE_PATH} \
--load ${CHECKPOINT_PATH} \
--fp16 \
--DDP-impl torch \
--tokenizer-type GPT2BPETokenizer \
--sample-input-file ${INPUT_PATH} \
--sample-output-file ${OUTPUT_PATH} \
--prompt-file ${PROMPT_PATH} \
--prompt-type response \
--num-prompt-examples 20 \
--task MSDP-PROMPT
# NOTE: If you use api for the model generation, please use
# the "--api-prompt" flag (setting this value as True).
......@@ -17,6 +17,7 @@ import torch
from .global_vars import get_args
from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches
from .global_vars import get_signal_handler
from .global_vars import update_num_microbatches
from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
......
......@@ -39,7 +39,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_data_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_biencoder_args(parser)
parser = _add_vit_args(parser)
parser = _add_vision_args(parser)
parser = _add_logging_args(parser)
parser = _add_inference_args(parser)
......@@ -71,6 +71,11 @@ def validate_args(args, defaults={}):
args.pipeline_model_parallel_size = min(
args.pipeline_model_parallel_size,
(args.world_size // args.tensor_model_parallel_size))
args.transformer_pipeline_model_parallel_size = (
args.pipeline_model_parallel_size - 1
if args.standalone_embedding_stage else
args.pipeline_model_parallel_size
)
# Checks.
model_parallel_size = args.pipeline_model_parallel_size * \
args.tensor_model_parallel_size
......@@ -142,7 +147,7 @@ def validate_args(args, defaults={}):
'number of layers is not divisible by number of layers per virtual ' \
'pipeline stage'
args.virtual_pipeline_model_parallel_size = \
(args.num_layers // args.pipeline_model_parallel_size) // \
(args.num_layers // args.transformer_pipeline_model_parallel_size) // \
args.num_layers_per_virtual_pipeline_stage
else:
args.virtual_pipeline_model_parallel_size = None
......@@ -250,17 +255,38 @@ def validate_args(args, defaults={}):
if args.fp32_residual_connection:
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
if args.weight_decay_incr_style == 'constant':
assert args.start_weight_decay is None
assert args.end_weight_decay is None
args.start_weight_decay = args.weight_decay
args.end_weight_decay = args.weight_decay
else:
assert args.start_weight_decay is not None
assert args.end_weight_decay is not None
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
# Persistent fused layer norm.
if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
args.no_persist_layer_norm = True
if args.rank == 0:
print('Persistent fused layer norm kernel is supported from '
'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
'Defaulting to no_persist_layer_norm=True')
# Activation checkpointing.
if args.distribute_checkpointed_activations:
assert args.tensor_model_parallel_size > 1, 'can distribute ' \
'checkpointed activations only across tensor model ' \
'parallel groups'
assert args.activations_checkpoint_method is not None, \
'for distribute-checkpointed-activations to work you '\
'for distributed checkpoint activations to work you '\
'need to use a activation-checkpoint method '
assert args.num_layers_per_virtual_pipeline_stage is None, \
'currently distrobuted checkpoint activations only supported for ' \
'nointerleaved pipeline parallelism'
assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
'distributed checkpoint activations are supported for pytorch ' \
'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
_print_args(args)
return args
......@@ -372,6 +398,9 @@ def _add_logging_args(parser):
group.add_argument('--log-memory-to-tensorboard',
action='store_true',
help='Enable memory logging to tensorboard.')
group.add_argument('--log-world-size-to-tensorboard',
action='store_true',
help='Enable world size logging to tensorboard.')
return parser
......@@ -385,6 +414,13 @@ def _add_regularization_args(parser):
help='Dropout probability for hidden state transformer.')
group.add_argument('--weight-decay', type=float, default=0.01,
help='Weight decay coefficient for L2 regularization.')
group.add_argument('--start-weight-decay', type=float,
help='Initial weight decay coefficient for L2 regularization.')
group.add_argument('--end-weight-decay', type=float,
help='End of run weight decay coefficient for L2 regularization.')
group.add_argument('--weight-decay-incr-style', type=str, default='constant',
choices=['constant', 'linear', 'cosine'],
help='Weight decay increment function.')
group.add_argument('--clip-grad', type=float, default=1.0,
help='Gradient clipping based on global L2 norm.')
group.add_argument('--adam-beta1', type=float, default=0.9,
......@@ -467,6 +503,9 @@ def _add_training_args(parser):
'by this value.')
group.add_argument('--exit-duration-in-mins', type=int, default=None,
help='Exit the program after this many minutes.')
group.add_argument('--exit-signal-handler', action='store_true',
help='Dynamically save the checkpoint and shutdown the '
'training if SIGTERM is received')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--no-masked-softmax-fusion',
......@@ -491,6 +530,11 @@ def _add_training_args(parser):
help='Disable asynchronous execution of '
'tensor-model-parallel all-reduce with weight '
'gradient compuation of a column-linear layer.')
group.add_argument('--no-persist-layer-norm', action='store_true',
help='Disable using persistent fused layer norm kernel. '
'This kernel supports only a set of hidden sizes. Please '
'check persist_ln_hidden_sizes if your hidden '
'size is supported.')
return parser
......@@ -500,6 +544,9 @@ def _add_initialization_args(parser):
group.add_argument('--seed', type=int, default=1234,
help='Random seed used for python, numpy, '
'pytorch, and cuda.')
group.add_argument('--data-parallel-random-init', action='store_true',
help='Enable random initialization of params '
'across data parallel ranks')
group.add_argument('--init-method-std', type=float, default=0.02,
help='Standard deviation of the zero mean normal '
'distribution used for weight initialization.')
......@@ -540,13 +587,13 @@ def _add_learning_rate_args(parser):
group.add_argument('--min-lr', type=float, default=0.0,
help='Minumum value for learning rate. The scheduler'
'clip values below this threshold.')
group.add_argument('--override-lr-scheduler', action='store_true',
group.add_argument('--override-opt_param-scheduler', action='store_true',
help='Reset the values of the scheduler (learning rate,'
'warmup iterations, minimum learning rate, maximum '
'number of iterations, and decay style from input '
'arguments and ignore values from checkpoints. Note'
'that all the above values will be reset.')
group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
help='Use checkpoint to set the values of the scheduler '
'(learning rate, warmup iterations, minimum learning '
'rate, maximum number of iterations, and decay style '
......@@ -668,6 +715,11 @@ def _add_distributed_args(parser):
help='Call torch.cuda.empty_cache() each iteration '
'(training and eval), to reduce fragmentation.'
'0=off, 1=moderate, 2=aggressive.')
group.add_argument('--standalone-embedding-stage', action='store_true',
default=False, help='If set, *input* embedding layer '
'is placed on its own pipeline stage, without any '
'transformer layers. (For T5, this flag currently only '
'affects the encoder embedding.)')
return parser
......@@ -814,16 +866,70 @@ def _add_biencoder_args(parser):
return parser
def _add_vit_args(parser):
group = parser.add_argument_group(title="vit")
def _add_vision_args(parser):
group = parser.add_argument_group(title="vision")
# general vision arguements
group.add_argument('--num-classes', type=int, default=1000,
help='num of classes in vision classificaiton task')
group.add_argument('--img-dim', type=int, default=224,
help='Image size for vision classification task')
group.add_argument('--img-h', type=int, default=224,
help='Image height for vision classification task')
group.add_argument('--img-w', type=int, default=224,
help='Image height for vision classification task')
group.add_argument('--num-channels', type=int, default=3,
help='Number of channels in input image data')
group.add_argument('--patch-dim', type=int, default=16,
help='patch dimension used in vit')
help='patch dimension')
group.add_argument('--classes-fraction', type=float, default=1.0,
help='training with fraction of classes.')
group.add_argument('--data-per-class-fraction', type=float, default=1.0,
help='training with fraction of data per class.')
group.add_argument('--no-data-sharding', action='store_false',
help='Disable data sharding.',
dest='data_sharding')
group.add_argument('--head-lr-mult', type=float, default=1.0,
help='learning rate multiplier for head during finetuning')
# pretraining type and backbone selection`
group.add_argument('--vision-pretraining', action='store_true',
help='flag to indicate vision pretraining')
group.add_argument('--vision-pretraining-type', type=str, default='classify',
choices=['classify', 'inpaint', 'dino'],
help='pretraining objectives')
group.add_argument('--vision-backbone-type', type=str, default='vit',
choices=['vit', 'mit', 'swin'],
help='backbone types types')
group.add_argument('--swin-backbone-type', type=str, default='tiny',
choices=['tiny', 'base', 'h3'],
help='pretraining objectives')
# inpainting arguments
group.add_argument('--mask-type', type=str, default='random',
choices=['random', 'row'],
help='mask types')
group.add_argument('--mask-factor', type=float, default=1.0,
help='mask size scaling parameter')
# dino arguments
group.add_argument('--iter-per-epoch', type=int, default=1250,
help='iterations per epoch')
group.add_argument('--dino-local-img-size', type=int, default=96,
help='Image size for vision classification task')
group.add_argument('--dino-local-crops-number', type=int, default=10,
help='Number of local crops')
group.add_argument('--dino-head-hidden-size', type=int, default=2048,
help='Hidden dimension size in dino head')
group.add_argument('--dino-bottleneck-size', type=int, default=256,
help='Bottle neck dimension in dino head ')
group.add_argument('--dino-freeze-last-layer', type=float, default=1,
help='Freezing last layer weights')
group.add_argument('--dino-norm-last-layer', action='store_true',
help='Disable Norm in last layer.')
group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
help='warump teacher temperature')
group.add_argument('--dino-teacher-temp', type=float, default=0.07,
help='teacher temperature')
group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
help='warmup teacher temperaure epochs')
return parser
......@@ -65,6 +65,8 @@ def check_checkpoint_args(checkpoint_args):
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
if args.data_parallel_random_init:
_compare('data_parallel_random_init')
if get_checkpoint_version() < 3.0:
_compare('tensor_model_parallel_size',
old_arg_name='model_parallel_size')
......@@ -72,7 +74,6 @@ def check_checkpoint_args(checkpoint_args):
_compare('tensor_model_parallel_size')
_compare('pipeline_model_parallel_size')
def ensure_directory_exists(filename):
"""Build filename's path if it does not already exists."""
dirname = os.path.dirname(filename)
......@@ -175,7 +176,33 @@ def read_metadata(tracker_filename):
return max_iter, release
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
def get_rng_state():
""" collect rng state across data parallel ranks """
args = get_args()
rng_state = {
'random_rng_state': random.getstate(),
'np_rng_state': np.random.get_state(),
'torch_rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state(),
'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}
rng_state_list = None
if torch.distributed.is_initialized() and \
mpu.get_data_parallel_world_size() > 1 and \
args.data_parallel_random_init:
rng_state_list = \
[None for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.all_gather_object(
rng_state_list,
rng_state,
group=mpu.get_data_parallel_group())
else:
rng_state_list = [rng_state]
return rng_state_list
def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
"""Save a model checkpoint."""
args = get_args()
......@@ -185,6 +212,9 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
iteration, args.save))
# collect rng state across data parallel ranks
rng_state = get_rng_state()
if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:
# Arguments, iteration, and model.
......@@ -203,17 +233,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
if not args.no_save_optim:
if optimizer is not None:
state_dict['optimizer'] = optimizer.state_dict()
if lr_scheduler is not None:
state_dict['lr_scheduler'] = lr_scheduler.state_dict()
if opt_param_scheduler is not None:
state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
# RNG states.
if not args.no_save_rng:
state_dict['random_rng_state'] = random.getstate()
state_dict['np_rng_state'] = np.random.get_state()
state_dict['torch_rng_state'] = torch.get_rng_state()
state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
state_dict['rng_tracker_states'] \
= mpu.get_cuda_rng_tracker().get_states()
state_dict["rng_state"] = rng_state
# Save.
checkpoint_name = get_checkpoint_name(args.save, iteration)
......@@ -418,7 +443,8 @@ def load_args_from_checkpoint(args, load_arg='load'):
_set_arg('num_layers_per_virtual_pipeline_stage')
return args
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
"""Load a model checkpoint and return the iteration.
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` of the checkpoint match the names of
......@@ -481,8 +507,11 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
try:
if optimizer is not None:
optimizer.load_state_dict(state_dict['optimizer'])
if lr_scheduler is not None:
lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
if opt_param_scheduler is not None:
if 'lr_scheduler' in state_dict: # backward compatbility
opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
else:
opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-optim or --finetune to prevent '
......@@ -493,15 +522,32 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
# rng states.
if not release and not args.finetune and not args.no_load_rng:
try:
random.setstate(state_dict['random_rng_state'])
np.random.set_state(state_dict['np_rng_state'])
torch.set_rng_state(state_dict['torch_rng_state'])
torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
# Check for empty states array
if not state_dict['rng_tracker_states']:
raise KeyError
mpu.get_cuda_rng_tracker().set_states(
state_dict['rng_tracker_states'])
if 'rng_state' in state_dict:
# access rng_state for data parallel rank
if args.data_parallel_random_init:
rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
else:
rng_state = state_dict['rng_state'][0]
random.setstate(rng_state['random_rng_state'])
np.random.set_state(rng_state['np_rng_state'])
torch.set_rng_state(rng_state['torch_rng_state'])
torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
# Check for empty states array
if not rng_state['rng_tracker_states']:
raise KeyError
mpu.get_cuda_rng_tracker().set_states(
rng_state['rng_tracker_states'])
else: # backward compatability
random.setstate(state_dict['random_rng_state'])
np.random.set_state(state_dict['np_rng_state'])
torch.set_rng_state(state_dict['torch_rng_state'])
torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
# Check for empty states array
if not state_dict['rng_tracker_states']:
raise KeyError
mpu.get_cuda_rng_tracker().set_states(
state_dict['rng_tracker_states'])
except KeyError:
print_rank_0('Unable to load rng state from checkpoint {}. '
'Specify --no-load-rng or --finetune to prevent '
......
......@@ -16,8 +16,10 @@
"""Dataloaders."""
import torch
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from megatron import get_args
from megatron import mpu
......@@ -39,11 +41,13 @@ def build_pretraining_data_loader(dataset, consumed_samples):
data_parallel_size=mpu.get_data_parallel_world_size())
elif args.dataloader_type == 'cyclic':
batch_sampler = MegatronPretrainingRandomSampler(
dataset,
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
data_parallel_size=mpu.get_data_parallel_world_size(),
data_sharding=args.data_sharding)
else:
raise Exception('{} dataloader type is not supported.'.format(
args.dataloader_type))
......@@ -103,16 +107,40 @@ class MegatronPretrainingSampler:
yield batch[start_idx:end_idx]
class RandomSeedDataset(Dataset):
def __init__(self, dataset):
args = get_args()
self.base_seed = args.seed
self.curr_seed = args.seed
self.dataset = dataset
def __len__(self):
return len(self.dataset)
def set_epoch(self, epoch):
self.curr_seed = self.base_seed + epoch
def __getitem__(self, idx):
seed = idx + self.curr_seed
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
return self.dataset[idx]
class MegatronPretrainingRandomSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size):
def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size, data_sharding):
# Keep a copy of input params for later use.
self.dataset = dataset
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.data_sharding = data_sharding
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.last_batch_size = \
......@@ -136,16 +164,30 @@ class MegatronPretrainingRandomSampler:
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
if isinstance(self.dataset, RandomSeedDataset):
self.dataset.set_epoch(self.epoch)
# data sharding and random sampling
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
if self.data_sharding:
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
else:
full_bucket_size = (self.total_samples // self.micro_batch_size) \
* self.micro_batch_size
full_bucket_offset = current_epoch_samples
g = torch.Generator()
g.manual_seed(self.epoch)
idx_range_total = \
torch.randperm(full_bucket_size, generator=g).tolist()
idx_range_active = idx_range_total[full_bucket_offset:]
idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size]
batch = []
# Last batch if not complete will be dropped.
......
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# code taken from
# https://github.com/pytorch/vision/blob/main/torchvision/datasets/folder.py
# added support for classes_fraction and data_per_class_fraction
from torchvision.datasets import VisionDataset
from PIL import Image
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
import numpy as np
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
"""Checks if a file is an allowed extension.
Args:
filename (string): path to a file
extensions (tuple of strings): extensions to consider (lowercase)
Returns:
bool: True if the filename ends with one of given extensions
"""
return filename.lower().endswith(extensions)
def is_image_file(filename: str) -> bool:
"""Checks if a file is an allowed image extension.
Args:
filename (string): path to a file
Returns:
bool: True if the filename ends with a known image extension
"""
return has_file_allowed_extension(filename, IMG_EXTENSIONS)
def make_dataset(
directory: str,
class_to_idx: Dict[str, int],
data_per_class_fraction: float,
extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class).
Args:
directory (str): root dataset directory
class_to_idx (Dict[str, int]): dictionary mapping class name to class index
extensions (optional): A list of allowed extensions.
Either extensions or is_valid_file should be passed. Defaults to None.
is_valid_file (optional): A function that takes path of a file
and checks if the file is a valid file
(used to check of corrupt files) both extensions and
is_valid_file should not be passed. Defaults to None.
Raises:
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
Returns:
List[Tuple[str, int]]: samples of a form (path_to_sample, class)
"""
instances = []
directory = os.path.expanduser(directory)
both_none = extensions is None and is_valid_file is None
both_something = extensions is not None and is_valid_file is not None
if both_none or both_something:
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
if extensions is not None:
def is_valid_file(x: str) -> bool:
return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
is_valid_file = cast(Callable[[str], bool], is_valid_file)
for target_class in sorted(class_to_idx.keys()):
class_index = class_to_idx[target_class]
target_dir = os.path.join(directory, target_class)
if not os.path.isdir(target_dir):
continue
local_instances = []
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
item = path, class_index
local_instances.append(item)
instances.extend(local_instances[0:int(len(local_instances) * data_per_class_fraction)])
return instances
class DatasetFolder(VisionDataset):
"""A generic data loader where the samples are arranged in this way: ::
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/[...]/xxz.ext
root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/[...]/asd932_.ext
Args:
root (string): Root directory path.
loader (callable): A function to load a sample given its path.
extensions (tuple[string]): A list of allowed extensions.
both extensions and is_valid_file should not be passed.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
target_transform (callable, optional): A function/transform that takes
in the target and transforms it.
is_valid_file (callable, optional): A function that takes path of a file
and check if the file is a valid file (used to check of corrupt files)
both extensions and is_valid_file should not be passed.
Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
targets (list): The class_index value for each image in the dataset
"""
def __init__(
self,
root: str,
loader: Callable[[str], Any],
extensions: Optional[Tuple[str, ...]] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
classes_fraction=1.0,
data_per_class_fraction=1.0,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> None:
super(DatasetFolder, self).__init__(root, transform=transform,
target_transform=target_transform)
self.classes_fraction = classes_fraction
self.data_per_class_fraction = data_per_class_fraction
classes, class_to_idx = self._find_classes(self.root)
samples = self.make_dataset(self.root,
class_to_idx,
self.data_per_class_fraction,
extensions,
is_valid_file)
if len(samples) == 0:
msg = "Found 0 files in subfolders of: {}\n".format(self.root)
if extensions is not None:
msg += "Supported extensions are: {}".format(",".join(extensions))
raise RuntimeError(msg)
self.loader = loader
self.extensions = extensions
self.total = len(samples)
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]
@staticmethod
def make_dataset(
directory: str,
class_to_idx: Dict[str, int],
data_per_class_fraction: float,
extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
return make_dataset(directory,
class_to_idx,
data_per_class_fraction,
extensions=extensions,
is_valid_file=is_valid_file)
def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
all_classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes = all_classes[0:int(len(all_classes) * self.classes_fraction)]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
curr_index = index
for x in range(self.total):
try:
path, target = self.samples[curr_index]
sample = self.loader(path)
break
except Exception as e:
curr_index = np.random.randint(0, self.total)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self) -> int:
return len(self.samples)
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
# TODO: specify the return type
def accimage_loader(path: str) -> Any:
import accimage
try:
return accimage.Image(path)
except IOError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)
def default_loader(path: str) -> Any:
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)
class ImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes path of an Image file
and check if the file is a valid file (used to check of corrupt files)
Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
classes_fraction=1.0,
data_per_class_fraction=1.0,
loader: Callable[[str], Any] = default_loader,
is_valid_file: Optional[Callable[[str], bool]] = None,
):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform,
target_transform=target_transform,
classes_fraction=classes_fraction,
data_per_class_fraction=data_per_class_fraction,
is_valid_file=is_valid_file)
self.imgs = self.samples
......@@ -13,46 +13,250 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import numpy as np
import torch
from torchvision import datasets, transforms
import torchvision.transforms as T
from torchvision import datasets
from megatron import get_args
from megatron.data.image_folder import ImageFolder
from megatron.data.autoaugment import ImageNetPolicy
from megatron.data.data_samplers import RandomSeedDataset
from PIL import Image, ImageFilter, ImageOps
def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
class GaussianBlur(object):
"""
Apply Gaussian Blur to the PIL image.
"""
def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
self.prob = p
self.radius_min = radius_min
self.radius_max = radius_max
# training dataset
train_data_path = os.path.join(data_path[0], "train")
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
process = [
transforms.RandomResizedCrop(crop_size),
transforms.RandomHorizontalFlip(),
]
if color_jitter:
process += [
transforms.ColorJitter(
brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
def __call__(self, img):
do_it = random.random() <= self.prob
if not do_it:
return img
return img.filter(
ImageFilter.GaussianBlur(
radius=random.uniform(self.radius_min, self.radius_max)
)
]
fp16_t = transforms.ConvertImageDtype(torch.half)
process += [ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t]
transform_train = transforms.Compose(process)
train_data = datasets.ImageFolder(
root=train_data_path, transform=transform_train
)
class Solarization(object):
"""
Apply Solarization to the PIL image.
"""
def __init__(self, p):
self.p = p
def __call__(self, img):
if random.random() < self.p:
return ImageOps.solarize(img)
else:
return img
class ClassificationTransform():
def __init__(self, image_size, train=True):
args = get_args()
assert args.fp16 or args.bf16
self.data_type = torch.half if args.fp16 else torch.bfloat16
if train:
self.transform = T.Compose([
T.RandomResizedCrop(image_size),
T.RandomHorizontalFlip(),
T.ColorJitter(0.4, 0.4, 0.4, 0.1),
ImageNetPolicy(),
T.ToTensor(),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
T.ConvertImageDtype(self.data_type)
])
else:
self.transform = T.Compose([
T.Resize(image_size),
T.CenterCrop(image_size),
T.ToTensor(),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
T.ConvertImageDtype(self.data_type)
])
def __call__(self, input):
output = self.transform(input)
return output
class InpaintingTransform():
def __init__(self, image_size, train=True):
args = get_args()
self.mask_factor = args.mask_factor
self.mask_type = args.mask_type
self.image_size = image_size
self.patch_size = args.patch_dim
self.mask_size = int(self.mask_factor*(image_size[0]/self.patch_size)*(image_size[1]/self.patch_size))
self.train = train
assert args.fp16 or args.bf16
self.data_type = torch.half if args.fp16 else torch.bfloat16
if self.train:
self.transform = T.Compose([
T.RandomResizedCrop(self.image_size),
T.RandomHorizontalFlip(),
T.ColorJitter(0.4, 0.4, 0.4, 0.1),
ImageNetPolicy(),
T.ToTensor(),
T.ConvertImageDtype(self.data_type)
])
else:
self.transform = T.Compose([
T.Resize(self.image_size, interpolation=2),
T.CenterCrop(self.image_size),
T.ToTensor(),
T.ConvertImageDtype(self.data_type)
])
def gen_mask(self, image_size, mask_size, mask_type, patch_size):
# output: mask as a list with indices for missing patches
action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
assert image_size[0] == image_size[1]
img_size_patch = image_size[0] // patch_size
# drop masked patches
mask = torch.zeros((image_size[0], image_size[1]), dtype=torch.float)
if mask_type == 'random':
x = torch.randint(0, img_size_patch, ())
y = torch.randint(0, img_size_patch, ())
for i in range(mask_size):
r = torch.randint(0, len(action_list), ())
x = torch.clamp(x + action_list[r][0], min=0, max=img_size_patch - 1)
y = torch.clamp(y + action_list[r][1], min=0, max=img_size_patch - 1)
x_offset = x * patch_size
y_offset = y * patch_size
mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1
else:
assert mask_type == 'row'
count = 0
for x in reversed(range(img_size_patch)):
for y in reversed(range(img_size_patch)):
if (count < mask_size):
count += 1
x_offset = x * patch_size
y_offset = y * patch_size
mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1
return mask
def __call__(self, input):
trans_input = self.transform(input)
mask = self.gen_mask(self.image_size, self.mask_size,
self.mask_type, self.patch_size)
mask = mask.unsqueeze(dim=0)
return trans_input, mask
class DinoTransform(object):
def __init__(self, image_size, train=True):
args = get_args()
self.data_type = torch.half if args.fp16 else torch.bfloat16
flip_and_color_jitter = T.Compose([
T.RandomHorizontalFlip(p=0.5),
T.RandomApply(
[T.ColorJitter(brightness=0.4, contrast=0.4,
saturation=0.2, hue=0.1)],
p=0.8
),
T.RandomGrayscale(p=0.2),
])
if args.fp16 or args.bf16:
normalize = T.Compose([
T.ToTensor(),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
T.ConvertImageDtype(self.data_type)
])
else:
normalize = T.Compose([
T.ToTensor(),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
# first global crop
scale_const = 0.4
self.global_transform1 = T.Compose([
T.RandomResizedCrop(image_size,
scale=(scale_const, 1),
interpolation=Image.BICUBIC),
flip_and_color_jitter,
GaussianBlur(1.0),
normalize
])
# second global crop
self.global_transform2 = T.Compose([
T.RandomResizedCrop(image_size,
scale=(scale_const, 1),
interpolation=Image.BICUBIC),
flip_and_color_jitter,
GaussianBlur(0.1),
Solarization(0.2),
normalize
])
# transformation for the local small crops
self.local_crops_number = args.dino_local_crops_number
self.local_transform = T.Compose([
T.RandomResizedCrop(args.dino_local_img_size,
scale=(0.05, scale_const),
interpolation=Image.BICUBIC),
flip_and_color_jitter,
GaussianBlur(p=0.5),
normalize
])
def __call__(self, image):
crops = []
crops.append(self.global_transform1(image))
crops.append(self.global_transform2(image))
for _ in range(self.local_crops_number):
crops.append(self.local_transform(image))
return crops
def build_train_valid_datasets(data_path, image_size=224):
args = get_args()
if args.vision_pretraining_type == 'classify':
train_transform = ClassificationTransform(image_size)
val_transform = ClassificationTransform(image_size, train=False)
elif args.vision_pretraining_type == 'inpaint':
train_transform = InpaintingTransform(image_size, train=False)
val_transform = InpaintingTransform(image_size, train=False)
elif args.vision_pretraining_type == 'dino':
train_transform = DinoTransform(image_size, train=True)
val_transform = ClassificationTransform(image_size, train=False)
else:
raise Exception('{} vit pretraining type is not supported.'.format(
args.vit_pretraining_type))
# training dataset
train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2]
train_data = ImageFolder(
root=train_data_path,
transform=train_transform,
classes_fraction=args.classes_fraction,
data_per_class_fraction=args.data_per_class_fraction
)
train_data = RandomSeedDataset(train_data)
# validation dataset
val_data_path = os.path.join(data_path[0], "val")
transform_val = transforms.Compose(
[
transforms.Resize(crop_size),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
normalize,
fp16_t
]
)
val_data = datasets.ImageFolder(
root=val_data_path, transform=transform_val
val_data_path = data_path[1]
val_data = ImageFolder(
root=val_data_path,
transform=val_transform
)
val_data = RandomSeedDataset(val_data)
return train_data, val_data
import signal
import torch
def get_world_size():
if torch.distributed.is_available() and torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()
else:
world_size = 1
return world_size
def get_device(local_rank=None):
backend = torch.distributed.get_backend()
if backend == 'nccl':
if local_rank is None:
device = torch.device('cuda')
else:
device = torch.device(f'cuda:{local_rank}')
elif backend == 'gloo':
device = torch.device('cpu')
else:
raise RuntimeError
return device
def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None):
if not torch.distributed.is_available() or \
not torch.distributed.is_initialized():
return [item]
device = get_device(local_rank)
if group is not None:
group_size = group.size()
else:
group_size = get_world_size()
tensor = torch.tensor([item], device=device, dtype=dtype)
output_tensors = [
torch.zeros(1, dtype=tensor.dtype, device=tensor.device)
for _ in range(group_size)
]
torch.distributed.all_gather(output_tensors, tensor, group, async_op)
output = [elem.item() for elem in output_tensors]
return output
class DistributedSignalHandler:
def __init__(self, sig=signal.SIGTERM):
self.sig = sig
def signals_received(self):
all_received = all_gather_item(
self._signal_received, dtype=torch.int32
)
return all_received
def __enter__(self):
self._signal_received = False
self.released = False
self.original_handler = signal.getsignal(self.sig)
def handler(signum, frame):
self._signal_received = True
signal.signal(self.sig, handler)
return self
def __exit__(self, type, value, tb):
self.release()
def release(self):
if self.released:
return False
signal.signal(self.sig, self.original_handler)
self.released = True
return True
......@@ -78,6 +78,12 @@ def load(args):
scaled_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_masked_softmax_cuda", sources, extra_cuda_flags)
# Softmax
sources=[srcpath / 'scaled_softmax.cpp',
srcpath / 'scaled_softmax_cuda.cu']
scaled_softmax_cuda = _cpp_extention_load_helper(
"scaled_softmax_cuda", sources, extra_cuda_flags)
# =================================
# Mixed precision fused layer norm.
# =================================
......
......@@ -90,6 +90,117 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
}
}
/*
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
int micro_batch_size,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
int itr_idx = i*element_count+it*WARP_SIZE;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
/*
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
......@@ -326,6 +437,98 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int att
return batches_per_block;
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_softmax_forward(
output_t *dst,
const input_t *src,
const input_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 1: // 2
scaled_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 2: // 4
scaled_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 3: // 8
scaled_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 4: // 16
scaled_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 5: // 32
scaled_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 6: // 64
scaled_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 7: // 128
scaled_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 8: // 256
scaled_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 9: // 512
scaled_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 10: // 1024
scaled_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 11: // 2048
scaled_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 12: // 4096
scaled_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
output_t *dst,
......@@ -338,7 +541,7 @@ void dispatch_scaled_masked_softmax_forward(
int attn_heads,
int pad_batches)
{
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
if (key_seq_len == 0) {
return;
} else {
......@@ -410,6 +613,10 @@ void dispatch_scaled_masked_softmax_forward(
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 12: // 4096
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
default:
break;
}
......@@ -427,7 +634,7 @@ void dispatch_scaled_masked_softmax_backward(
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 );
if (key_seq_len == 0) {
return;
} else {
......@@ -498,6 +705,11 @@ void dispatch_scaled_masked_softmax_backward(
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 12: // 4096
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
default:
break;
}
......
......@@ -44,7 +44,7 @@ torch::Tensor fwd_cuda(
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
......
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(
torch::Tensor const& input,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_softmax::fwd,
"Self Multihead Attention scaled, softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_softmax::bwd,
"Self Multihead Attention scaled, softmax -- Backward.");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment