Unverified Commit f1b3d60c authored by binmakeswell, committed by GitHub

[example] reorganize for community examples (#3557)

parent 1a809edd
@@ -10,9 +10,12 @@

## Overview

This folder provides several examples accelerated by Colossal-AI.
Folders such as `images` and `language` include a wide range of deep learning tasks and applications.
The `community` folder aims to create a collaborative platform for developers to contribute exotic features built on top of Colossal-AI.
The `tutorial` folder is for everyone to quickly try out the different features in Colossal-AI.

You can find applications such as Chatbot, AIGC and Biomedicine in the [Applications](https://github.com/hpcaitech/ColossalAI/tree/main/applications) directory.

## Folder Structure

@@ -52,3 +55,10 @@ Therefore, it is essential for the example contributors to know how to integrate

2. Configure your testing parameters, such as the number of steps and the batch size, in `test_ci.sh`. Keep these parameters small so that each example takes only a few minutes.
3. Export your dataset path with the prefix `/data` and make sure you have a copy of the dataset in the `/data/scratch/examples-data` directory on the CI machine. Community contributors can contact us via Slack to request that the dataset be downloaded onto the CI machine.
4. Implement the logic, such as dependency setup and example execution, in `test_ci.sh`; a minimal sketch is shown below.
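For reference, a hypothetical `test_ci.sh` might look like the following sketch; the script name `train.py` and its flags are placeholders rather than files in this repository.

```bash
#!/bin/bash
set -euo pipefail

# Step 4: dependency setup
pip install -r requirements.txt

# Step 3: dataset path exported with the /data prefix
export DATA=/data/scratch/examples-data/my_dataset

# Step 2: keep the number of steps and the batch size small so CI finishes in minutes
colossalai run --nproc_per_node 4 train.py --max_steps 5 --batch_size 8
```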
## Community Dependencies

We are happy to introduce the following community repositories that are powered by Colossal-AI:
- [lightning-ColossalAI](https://github.com/Lightning-AI/lightning)
- [HCP-Diffusion](https://github.com/7eu7d7/HCP-Diffusion)
- [KoChatGPT](https://github.com/airobotlab/KoChatGPT)
- [minichatgpt](https://github.com/juncongmoo/minichatgpt)
# Community Examples

Community-driven examples is an initiative that allows users to share their own examples with the Colossal-AI community, fostering a sense of community and making it easy for others to access and benefit from shared work. The primary goal of community-driven examples is to have a community-maintained collection of diverse and exotic functionalities built on top of the Colossal-AI package.
If a community example doesn't work as expected, you can [open an issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) and @ the author to report it.
| Example | Description | Code Example | Colab |Author |
|:------------------|:---------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------|:-----------------------------------------|-----------------------------------------------------:|
| RoBERTa | Adding RoBERTa for SFT and Prompts model training | [RoBERTa](./roberta) | - | [YY Lin](https://github.com/yynil) (Moore Threads) |
| TransformerEngine FP8 | Adding TransformerEngine with FP8 training | [TransformerEngine FP8](./fp8) | - | [Kirthi Shankar Sivamani](https://github.com/ksivaman) (NVIDIA) |
|...|...|...|...|...|
## Looking for Examples
* [Swin-Transformer](https://github.com/microsoft/Swin-Transformer)
* [T5](https://github.com/google-research/text-to-text-transfer-transformer)
* [Segment Anything (SAM)](https://github.com/facebookresearch/segment-anything)
* [ControlNet](https://github.com/lllyasviel/ControlNet)
* [Consistency Models](https://github.com/openai/consistency_models)
* [MAE](https://github.com/facebookresearch/mae)
* [CLIP](https://github.com/openai/CLIP)
You are welcome to [open an issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) to share your insights and needs.
## How to get involved
To join our community-driven initiative, please visit the [Colossal-AI examples](https://github.com/hpcaitech/ColossalAI/tree/main/examples), review the provided information, and explore the codebase.
To contribute, create a new issue outlining your proposed feature or enhancement, and our team will review it and provide feedback. If you are confident enough, you can also submit a PR directly. We look forward to collaborating with you on this exciting project!
# Basic MNIST Example with optional FP8 via TransformerEngine

[TransformerEngine](https://github.com/NVIDIA/TransformerEngine) is a library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper GPUs, to provide better performance with lower memory utilization in both training and inference.

Thanks to NVIDIA for contributing this tutorial.

```bash
python main.py
python main.py --use-te   # Linear layers from TransformerEngine
python main.py --use-fp8  # FP8 + TransformerEngine for Linear layers
```

> We are working to integrate it with Colossal-AI and will finish it soon.
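Under the hood, `--use-te` swaps `torch.nn.Linear` for `te.Linear`, and `--use-fp8` additionally wraps forward passes in `te.fp8_autocast`. A minimal sketch of that pattern, assuming TransformerEngine is installed and an FP8-capable (Hopper or newer) GPU is available:

```python
import torch
import transformer_engine.pytorch as te

# te.Linear is a drop-in replacement for torch.nn.Linear.
layer = te.Linear(768, 768).cuda()
x = torch.randn(32, 768, device="cuda")

# FP8 compute is only active inside the autocast context;
# outside of it, the layer runs in the usual precision.
with te.fp8_autocast(enabled=True):
    out = layer(x)
out.sum().backward()
```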
@@ -3,12 +3,13 @@
# See LICENSE for license information.
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms

try:
    from transformer_engine import pytorch as te
@@ -18,6 +19,7 @@ except (ImportError, ModuleNotFoundError):

class Net(nn.Module):

    def __init__(self, use_te=False):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
@@ -62,12 +64,10 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print(f"Train Epoch: {epoch} "
                  f"[{batch_idx * len(data)}/{len(train_loader.dataset)} "
                  f"({100. * batch_idx / len(train_loader):.0f}%)]\t"
                  f"Loss: {loss.item():.6f}")
            if args.dry_run:
                break
@@ -83,6 +83,7 @@ def calibrate(model, device, test_loader):
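            # calibrating=True runs the forward pass in higher precision while
            # recording FP8 scaling statistics, so the saved checkpoint can
            # later be used for FP8 inference (--use-fp8-infer).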
            with te.fp8_autocast(enabled=False, calibrating=True):
                output = model(data)
def test(model, device, test_loader, use_fp8):
    """Testing function."""
    model.eval()
@@ -93,21 +94,15 @@ def test(model, device, test_loader, use_fp8):
            data, target = data.to(device), target.to(device)
            with te.fp8_autocast(enabled=use_fp8):
                output = model(data)
            test_loss += F.nll_loss(output, target, reduction="sum").item()    # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)    # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(f"\nTest set: Average loss: {test_loss:.4f}, "
          f"Accuracy: {correct}/{len(test_loader.dataset)} "
          f"({100. * correct / len(test_loader.dataset):.0f}%)\n")

def main():
@@ -154,9 +149,7 @@ def main():
        default=False,
        help="quickly check a single pass",
    )
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
@@ -170,15 +163,12 @@ def main():
        default=False,
        help="For Saving the current Model",
    )
    parser.add_argument("--use-fp8",
                        action="store_true",
                        default=False,
                        help="Use FP8 for inference and training without recalibration")
    parser.add_argument("--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only")
    parser.add_argument("--use-te", action="store_true", default=False, help="Use Transformer Engine")
    args = parser.parse_args()
    use_cuda = torch.cuda.is_available()
@@ -205,9 +195,7 @@ def main():
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
    dataset2 = datasets.MNIST("../data", train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
@@ -227,7 +215,7 @@ def main():
    if args.save_model or args.use_fp8_infer:
        torch.save(model.state_dict(), "mnist_cnn.pt")
        print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8_infer))
        weights = torch.load("mnist_cnn.pt")
        model.load_state_dict(weights)
        test(model, device, test_loader, args.use_fp8_infer)
...
@@ -11,7 +11,7 @@ ssh-keygen
ssh-copy-id -i ~/.ssh/id_rsa.pub ip_destination
```

- On all hosts, edit `/etc/hosts` to record each host's name and IP. An example is shown below.

```bash
192.168.2.1 GPU001
@@ -29,7 +29,7 @@ ssh-copy-id -i ~/.ssh/id_rsa.pub ip_destination
service ssh restart
```

## 1. Corpus Preprocessing

```bash
cd preprocessing
```
...
@@ -21,7 +21,7 @@ This folder is used to preprocess chinese corpus with Whole Word Masked. You can

<span id='Split Sentence'/>

### 2.1. Split Sentences & Split Data into Multiple Shards:

Firstly, each file has multiple documents, and each document contains multiple sentences. Split sentences by punctuation marks such as `。!`. **Secondly, split the data into multiple shards based on the server hardware (CPU, CPU memory, hard disk) and the corpus size.** Each shard contains a part of the corpus, and the model needs to train on all the shards to complete one epoch.

In this example, a 200G corpus is split into 100 shards, and each shard is about 2G. The shard size is memory-dependent, taking into account the number of servers, the memory used by the tokenizer, and the memory used by multi-process training to read the shards (n-way data parallelism requires n\*shard_size memory; for example, reading 2G shards with 8-way data parallelism takes roughly 16G of CPU memory per server). **To sum up, data preprocessing and model pretraining require fighting with hardware, not just the GPU.**
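As a toy illustration of the sentence-splitting rule (the full logic, including quotation marks and length limits, lives in `sentence_split.py` shown further below):

```python
import re

# Simplified sketch: insert a newline after each sentence-final punctuation
# mark, then split the document on those newlines.
doc = "我今天去打篮球。下周请假!"
print(re.sub('([。!?])', r'\1\n', doc).splitlines())
# -> ['我今天去打篮球。', '下周请假!']
```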
```python
@@ -49,7 +49,7 @@ python sentence_split.py --input_path /orginal_corpus --output_path /shard --sha
]
```

<summary><b>Output txt:</b></summary>

```
我今天去打篮球。
@@ -76,7 +76,7 @@ make

* `--input_path`: location of all shards with split sentences, e.g., /shard/0.txt, /shard/1.txt ...
* `--output_path`: location of all h5 files with token_id, input_mask, segment_ids and masked_lm_positions, e.g., /h5/0.h5, /h5/1.h5 ...
* `--tokenizer_path`: tokenizer path containing the HuggingFace tokenizer.json. Download config.json, special_tokens_map.json, vocab.txt and tokenizer.json from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main)
* `--backend`: python or c++; **specifying c++ gives a faster preprocessing speed**
* `--dupe_factor`: specifies how many times the preprocessor repeats to create the input from the same article/document
* `--worker`: number of processes
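For instance, a typical invocation could look like the following; `tokenize_mask.py` is a placeholder name here, so substitute the actual preprocessing script in this folder:

```bash
python tokenize_mask.py --input_path /shard --output_path /h5 \
    --tokenizer_path /path/to/tokenizer --backend c++ --dupe_factor 1 --worker 32
```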
@@ -91,7 +91,7 @@ make
下周请假。
```

<summary><b>Output h5+numpy:</b></summary>

```
'input_ids': [[id0,id1,id2,id3,id4,id5,id6,0,0..],
@@ -102,4 +102,4 @@ make
              ...]
'masked_lm_positions': [[label1,-1,-1,label2,-1...],
                        ...]
```
\ No newline at end of file
import collections
import logging
import os
import random
import time
from enum import IntEnum
from random import choice

import jieba
import torch

jieba.setLogLevel(logging.CRITICAL)

import re

import mask
import numpy as np

PAD = 0
MaskedLMInstance = collections.namedtuple("MaskedLMInstance", ["index", "label"])

def map_to_numpy(data):
@@ -22,6 +24,7 @@ def map_to_numpy(data):

class PreTrainingDataset():

    def __init__(self,
                 tokenizer,
                 max_seq_length,
@@ -43,17 +46,15 @@ class PreTrainingDataset():
        self.mlm_tamper_p = 0.05
        self.mlm_maintain_p = 0.1

    def tokenize(self, doc):
        temp = []
        for d in doc:
            temp.append(self.tokenizer.tokenize(d))
        return temp

    def create_training_instance(self, instance):
        is_next = 1
        raw_text_list = self.get_new_segment(instance)
        tokens_a = raw_text_list
        assert len(tokens_a) == len(instance)
        # tokens_a, tokens_b, is_next = instance.get_values()
@@ -83,8 +84,9 @@ class PreTrainingDataset():
        # Get Masked LM predictions
        if self.backend == 'c++':
            output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(
                tokens, original_tokens, self.vocab_words, self.tokenizer.vocab, self.max_predictions_per_seq,
                self.masked_lm_prob)
        elif self.backend == 'python':
            output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens)
@@ -102,29 +104,25 @@ class PreTrainingDataset():
            map_to_numpy(input_mask),
            map_to_numpy(segment_ids),
            map_to_numpy(masked_lm_output),
            map_to_numpy([is_next])
        ])

    def create_masked_lm_predictions(self, tokens):
        cand_indexes = []
        for i, token in enumerate(tokens):
            if token == "[CLS]" or token == "[SEP]":
                continue
            if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])
                # cand_indexes.append(i)

        random.shuffle(cand_indexes)
        output_tokens = list(tokens)
        num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob))))

        masked_lms = []
        covered_indexes = set()
@@ -145,13 +143,10 @@ class PreTrainingDataset():
                    masked_token = tokens[index]
                # 10% replace w/ random word
                else:
                    masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)]
            output_tokens[index] = masked_token
            masked_lms.append(MaskedLMInstance(index=index, label=tokens[index]))

        masked_lms = sorted(masked_lms, key=lambda x: x.index)
        masked_lm_output = [-1] * len(output_tokens)
@@ -160,7 +155,6 @@ class PreTrainingDataset():
        return (output_tokens, masked_lm_output)
    def get_new_segment(self, segment):
        """
        Input a sentence and return a processed sentence: to support Chinese whole-word masking, characters that belong to the same segmented word are marked with the special prefix "##", so that subsequent processing modules know which characters belong to the same word.
@@ -171,7 +165,7 @@ class PreTrainingDataset():
        new_segment = []
        i = 0
        while i < len(segment):
            if len(self.rec.findall(segment[i])) == 0:
                new_segment.append(segment[i])
                i += 1
                continue
@@ -180,10 +174,10 @@ class PreTrainingDataset():
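            # Match the longest jieba-segmented word first (up to 3 characters
            # starting at position i); the continuation characters of a matched
            # word receive the "##" prefix below.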
            for length in range(3, 0, -1):
                if i + length > len(segment):
                    continue
                if ''.join(segment[i:i + length]) in seq_cws_dict:
                    new_segment.append(segment[i])
                    for l in range(1, length):
                        new_segment.append('##' + segment[i + l])
                    i += length
                    has_add = True
                    break
@@ -192,7 +186,6 @@ class PreTrainingDataset():
                i += 1
        return new_segment
    def create_whole_masked_lm_predictions(self, tokens):
        """Creates the predictions for the masked LM objective."""
@@ -209,18 +202,16 @@ class PreTrainingDataset():
            # Note that Whole Word Masking does *not* change the training code
            # at all -- we still predict each WordPiece independently, softmaxed
            # over the entire vocabulary.
            if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])

        random.shuffle(cand_indexes)
        output_tokens = [t[2:] if len(self.whole_rec.findall(t)) > 0 else t for t in tokens]    # strip the "##" prefix

        num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob))))
        masked_lms = []
        covered_indexes = set()
@@ -248,14 +239,18 @@ class PreTrainingDataset():
                else:
                    # 10% of the time, keep original
                    if random.random() < 0.5:
                        masked_token = tokens[index][2:] if len(self.whole_rec.findall(
                            tokens[index])) > 0 else tokens[index]    # strip the "##" prefix
                    # 10% of the time, replace with random word
                    else:
                        masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)]
                output_tokens[index] = masked_token
                masked_lms.append(
                    MaskedLMInstance(
                        index=index,
                        label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index]))

        assert len(masked_lms) <= num_to_predict
        masked_lms = sorted(masked_lms, key=lambda x: x.index)
        masked_lm_output = [-1] * len(output_tokens)
...
#include <math.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <algorithm>
#include <chrono>
#include <iostream>
#include <limits>
#include <map>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace py = pybind11;
const int32_t LONG_SENTENCE_LEN = 512;
struct MaskedLMInstance {
  int index;
  std::string label;
  MaskedLMInstance(int index, std::string label) {
    this->index = index;
    this->label = label;
  }
};
auto get_new_segment(
    std::vector<std::string> segment, std::vector<std::string> segment_jieba,
    const std::vector<bool> chinese_vocab) {  // const
                                              // std::unordered_set<std::string>
                                              // &chinese_vocab
  std::unordered_set<std::string> seq_cws_dict;
  for (auto word : segment_jieba) {
    seq_cws_dict.insert(word);
  }
  int i = 0;
  std::vector<std::string> new_segment;
  int segment_size = segment.size();
  while (i < segment_size) {
    if (!chinese_vocab[i]) {  // chinese_vocab.find(segment[i]) ==
                              // chinese_vocab.end()
      new_segment.emplace_back(segment[i]);
      i += 1;
      continue;
    }
    bool has_add = false;
    for (int length = 3; length >= 1; length--) {
      if (i + length > segment_size) {
        continue;
      }
      std::string chinese_word = "";
      for (int j = i; j < i + length; j++) {
        chinese_word += segment[j];
      }
      if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) {
        new_segment.emplace_back(segment[i]);
        for (int j = i + 1; j < i + length; j++) {
          new_segment.emplace_back("##" + segment[j]);
        }
        i += length;
        has_add = true;
        break;
      }
    }
    if (!has_add) {
      new_segment.emplace_back(segment[i]);
      i += 1;
    }
  }
  return new_segment;
}
bool startsWith(const std::string &s, const std::string &sub) {
  return s.find(sub) == 0 ? true : false;
}
auto create_whole_masked_lm_predictions(
    std::vector<std::string> &tokens,
    const std::vector<std::string> &original_tokens,
    const std::vector<std::string> &vocab_words,
    std::map<std::string, int> &vocab, const int max_predictions_per_seq,
    const double masked_lm_prob) {
  // for (auto item : vocab) {
  //   std::cout << "key=" << std::string(py::str(item.first)) << ", "
  //             << "value=" << std::string(py::str(item.second)) << std::endl;
  // }
  std::vector<std::vector<int> > cand_indexes;
  std::vector<int> cand_temp;
  int tokens_size = tokens.size();
  std::string prefix = "##";
  bool do_whole_masked = true;

  for (int i = 0; i < tokens_size; i++) {
    if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") {
      continue;
    }
    if (do_whole_masked && (cand_temp.size() > 0) &&
        (tokens[i].rfind(prefix, 0) == 0)) {
      cand_temp.emplace_back(i);
    } else {
      if (cand_temp.size() > 0) {
        cand_indexes.emplace_back(cand_temp);
      }
      cand_temp.clear();
      cand_temp.emplace_back(i);
    }
  }
  // Flush the final word group collected by the loop above so that the last
  // word is also a masking candidate.
  if (cand_temp.size() > 0) {
    cand_indexes.emplace_back(cand_temp);
  }
  auto seed = std::chrono::system_clock::now().time_since_epoch().count();
  std::shuffle(cand_indexes.begin(), cand_indexes.end(),
               std::default_random_engine(seed));
  // for (auto i : cand_indexes) {
  //   for (auto j : i) {
  //     std::cout << tokens[j] << " ";
  //   }
  //   std::cout << std::endl;
  // }
  // for (auto i : output_tokens) {
  //   std::cout << i;
  // }
  // std::cout << std::endl;
  int num_to_predict = std::min(max_predictions_per_seq,
                                std::max(1, int(tokens_size * masked_lm_prob)));
  // std::cout << num_to_predict << std::endl;

  std::set<int> covered_indexes;
  std::vector<int> masked_lm_output(tokens_size, -1);
  int vocab_words_len = vocab_words.size();
  std::default_random_engine e(seed);
  std::uniform_real_distribution<double> u1(0.0, 1.0);
  std::uniform_int_distribution<unsigned> u2(0, vocab_words_len - 1);
  int mask_cnt = 0;
  std::vector<std::string> output_tokens;
  output_tokens = original_tokens;

  for (auto index_set : cand_indexes) {
    if (mask_cnt > num_to_predict) {
      break;
    }
    int index_set_size = index_set.size();
    if (mask_cnt + index_set_size > num_to_predict) {
      continue;
    }
    bool is_any_index_covered = false;
    for (auto index : index_set) {
      if (covered_indexes.find(index) != covered_indexes.end()) {
        is_any_index_covered = true;
        break;
      }
    }
    if (is_any_index_covered) {
      continue;
    }
    for (auto index : index_set) {
      covered_indexes.insert(index);
      std::string masked_token;
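      // 80%: replace with [MASK]; otherwise 50/50 keep the original token or
      // substitute a random vocabulary token (BERT's 80/10/10 masking scheme).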
      if (u1(e) < 0.8) {
        masked_token = "[MASK]";
      } else {
        if (u1(e) < 0.5) {
          masked_token = output_tokens[index];
        } else {
          int random_index = u2(e);
          masked_token = vocab_words[random_index];
        }
      }
      // masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index]));
      masked_lm_output[index] = vocab[output_tokens[index]];
      output_tokens[index] = masked_token;
      mask_cnt++;
    }
  }
  // for (auto p : masked_lms) {
  //   masked_lm_output[p.index] = vocab[p.label];
  // }
  return std::make_tuple(output_tokens, masked_lm_output);
}

PYBIND11_MODULE(mask, m) {
  m.def("create_whole_masked_lm_predictions",
        &create_whole_masked_lm_predictions);
  m.def("get_new_segment", &get_new_segment);
}
import argparse
import functools
import json
import multiprocessing
import os
import re
import time
from typing import List

from tqdm import tqdm

def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]:
    sent_list = []
    try:
        if flag == "zh":
            document = re.sub('(?P<quotation_mark>([。?!…](?![”’"\'])))', r'\g<quotation_mark>\n', document)
            document = re.sub('(?P<quotation_mark>([。?!]|…{1,2})[”’"\'])', r'\g<quotation_mark>\n', document)
        elif flag == "en":
            document = re.sub('(?P<quotation_mark>([.?!](?![”’"\'])))', r'\g<quotation_mark>\n', document)
            document = re.sub('(?P<quotation_mark>([?!.]["\']))', r'\g<quotation_mark>\n',
                              document)    # Special quotation marks
        else:
            document = re.sub('(?P<quotation_mark>([。?!….?!](?![”’"\'])))', r'\g<quotation_mark>\n', document)
            document = re.sub('(?P<quotation_mark>(([。?!.!?]|…{1,2})[”’"\']))', r'\g<quotation_mark>\n',
                              document)    # Special quotation marks

        sent_list_ori = document.splitlines()
        for sent in sent_list_ori:
@@ -43,17 +45,15 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s
    return sent_list

def get_sent(output_path, input_path, fin_list=[], host=-1, seq_len=512) -> None:
    workers = 32

    if input_path[-1] == '/':
        input_path = input_path[:-1]

    cur_path = os.path.join(output_path, str(host) + '.txt')
    new_split_sentence = functools.partial(split_sentence, limit=seq_len - 2)
    with open(cur_path, 'w', encoding='utf-8') as f:
        for fi, fin_path in enumerate(fin_list):
            if not os.path.exists(os.path.join(input_path, fin_path[0])):
@@ -62,7 +62,7 @@ def get_sent(output_path,
                continue

            print("Processing ", fin_path[0], " ", fi)

            with open(os.path.join(input_path, fin_path[0]), 'r') as fin:
                f_data = [l['content'] for l in json.load(fin)]
@@ -99,17 +99,17 @@ def getFileSize(filepath, shard):
            real_shard.append(temp)
            accu_size = 0
            temp = []

    if len(temp) > 0:
        real_shard.append(temp)

    return real_shard

def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'):
    import socket
    host = int(socket.gethostname().split(server_name)[-1])

    fin_list = real_shard[server_num * base + host - 1]
    print(fin_list)
    print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}')
@@ -126,28 +126,24 @@ if __name__ == '__main__':
    parser.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence')
    args = parser.parse_args()

    server_num = args.server_num
    seq_len = args.seq_len
    shard = args.shard
    input_path = args.input_path
    output_path = args.output_path

    real_shard = getFileSize(input_path, shard)

    start = time.time()
    for index, shard in enumerate(real_shard):
        get_sent(output_path, input_path, fin_list=shard, host=index, seq_len=seq_len)
    print(f'cost {str(time.time() - start)}')
    # if you have multiple servers, you can use the code below or adapt it to openmpi
    # for i in range(len(real_shard) // server_num + 1):
    #     fin_list, host = get_start_end(real_shard, i)
    #     start = time.time()
    #     get_sent(output_path,
    #              input_path,
...
import argparse
import multiprocessing
import os
import socket
import time
from random import shuffle

import h5py
import numpy as np
import psutil
from tqdm import tqdm
from transformers import AutoTokenizer

from get_mask import PreTrainingDataset

def get_raw_instance(document, max_sequence_length=512):
    """
    Get the initial training instances: split the whole segment into multiple parts according to max_sequence_length, and return them as multiple processed instances.
    :param document: document
sizes = [len(seq) for seq in document] sizes = [len(seq) for seq in document]
result_list = [] result_list = []
curr_seq = [] curr_seq = []
sz_idx = 0 sz_idx = 0
while sz_idx < len(sizes): while sz_idx < len(sizes):
if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0: if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0:
curr_seq += document[sz_idx] curr_seq += document[sz_idx]
sz_idx += 1 sz_idx += 1
elif sizes[sz_idx] >= max_sequence_length_allowed: elif sizes[sz_idx] >= max_sequence_length_allowed:
if len(curr_seq) > 0: if len(curr_seq) > 0:
result_list.append(curr_seq) result_list.append(curr_seq)
curr_seq = [] curr_seq = []
result_list.append(document[sz_idx][ : max_sequence_length_allowed]) result_list.append(document[sz_idx][:max_sequence_length_allowed])
sz_idx += 1 sz_idx += 1
else: else:
result_list.append(curr_seq) result_list.append(curr_seq)
curr_seq = [] curr_seq = []
if len(curr_seq) > max_sequence_length_allowed / 2: # /2 if len(curr_seq) > max_sequence_length_allowed / 2: # /2
result_list.append(curr_seq) result_list.append(curr_seq)
# num_instance=int(len(big_list)/max_sequence_length_allowed)+1 # num_instance=int(len(big_list)/max_sequence_length_allowed)+1
@@ -70,8 +70,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):
            # document = line
            # if len(document.split("<sep>")) <= 3:
            #     continue
            if len(line) > 0 and line[:2] == "]]":    # This is end of document
                documents.append(document)
                document = []
            elif len(line) >= 2:
@@ -84,8 +83,8 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):
    # print(len(documents))
    # print(len(documents[0]))
    # print(documents[0][0:10])
    import multiprocessing
    from typing import List

    ans = []
    for docs in tqdm(documents):
@@ -98,7 +97,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):
        raw_ins = get_raw_instance(a)
        instances.extend(raw_ins)
    del ans

    print('len instance', len(instances))

    sen_num = len(instances)
@@ -116,21 +115,15 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):
        masked_lm_output[index] = mask_dict[3]

    with h5py.File(f'/output/{host}.h5', 'w') as hf:
        hf.create_dataset("input_ids", data=input_ids)
        hf.create_dataset("input_mask", data=input_mask)
        hf.create_dataset("segment_ids", data=segment_ids)
        hf.create_dataset("masked_lm_positions", data=masked_lm_output)

    del instances
def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_factor, seq_len, file_name):
    if os.path.exists(os.path.join(output_path, f'{file_name}.h5')):
        print(f'{file_name}.h5 exists')
@@ -144,8 +137,7 @@ def split_numpy_chunk_pool(input_path,
        document = []
        for i, line in enumerate(tqdm(fd)):
            line = line.strip()
            if len(line) > 0 and line[:2] == "]]":    # This is end of document
                documents.append(document)
                document = []
            elif len(line) >= 2:
@@ -153,7 +145,7 @@ def split_numpy_chunk_pool(input_path,
        if len(document) > 0:
            documents.append(document)

    print(f'read_file cost {time.time() - s}, length is {len(documents)}')

    ans = []
    s = time.time()
    pool = multiprocessing.Pool(worker)
@@ -169,7 +161,7 @@ def split_numpy_chunk_pool(input_path,
        raw_ins = get_raw_instance(a, max_sequence_length=seq_len)
        instances.extend(raw_ins)
    del ans

    print('len instance', len(instances))

    new_instances = []
@@ -199,10 +191,10 @@ def split_numpy_chunk_pool(input_path,
    print((time.time() - s) / 60)
    with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf:
        hf.create_dataset("input_ids", data=input_ids)
        hf.create_dataset("input_mask", data=input_mask)
        hf.create_dataset("segment_ids", data=segment_ids)
        hf.create_dataset("masked_lm_positions", data=masked_lm_output)

    del instances
@@ -212,22 +204,31 @@ if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--tokenizer_path', type=str, required=True, help='path of tokenizer')
    parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
    parser.add_argument('--max_predictions_per_seq',
                        type=int,
                        default=80,
                        help='maximum number of masked tokens per sequence')
    parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence')
    parser.add_argument('--output_path', type=str, required=True, help='output path of h5 which contains token ids')
    parser.add_argument('--backend',
                        type=str,
                        default='python',
                        help='backend of mask token: python, c++, or numpy')
    parser.add_argument(
        '--dupe_factor',
        type=int,
        default=1,
        help='specifies how many times the preprocessor repeats to create the input from the same article/document')
    parser.add_argument('--worker', type=int, default=32, help='number of processes')
    parser.add_argument('--server_num', type=int, default=10, help='number of servers')
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
    pretrain_data = PreTrainingDataset(tokenizer,
                                       args.seq_len,
                                       args.backend,
                                       max_predictions_per_seq=args.max_predictions_per_seq)

    data_len = len(os.listdir(args.input_path))

    for i in range(data_len):
@@ -235,15 +236,10 @@ if __name__ == '__main__':
        if os.path.exists(input_path):
            start = time.time()
            print(f'process {input_path}')
            split_numpy_chunk_pool(input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor,
                                   args.seq_len, i)
            end_ = time.time()
            print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))
            print(f'has cost {(end_ - start) / 60}')
            print('-' * 100)
            print('')
@@ -257,9 +253,9 @@ if __name__ == '__main__':
    # if os.path.exists(input_path):
    #     start = time.time()
    #     print(f'I am server {host}, process {input_path}')
    #     split_numpy_chunk_pool(input_path,
    #                            args.output_path,
    #                            pretrain_data,
    #                            args.worker,
    #                            args.dupe_factor,
    #                            args.seq_len,
@@ -269,5 +265,3 @@ if __name__ == '__main__':
    #     print(f'has cost {(end_ - start) / 60}')
    #     print('-' * 100)
    #     print('')
@@ -19,6 +19,5 @@ bash run_pretrain.sh
bash run_pretrain_resume.sh
```

* `--resume_train`: whether to resume training
* `--load_pretrain_model`: absolute path which contains the model checkpoint
* `--load_optimizer_lr`: absolute path which contains the optimizer checkpoint
import colossalai

__all__ = ['parse_args']
def parse_args():
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--distplan",
        type=str,
        default='CAI_Gemini',
        help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
    )
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default='cpu',
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--shardinit",
        action='store_true',
        help=
        "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
    )
    parser.add_argument('--lr', type=float, required=True, help='initial learning rate')
    parser.add_argument('--epoch', type=int, required=True, help='number of epochs')
    parser.add_argument('--data_path_prefix', type=str, required=True, help="location of the train data corpus")
    parser.add_argument('--eval_data_path_prefix',
                        type=str,
                        required=True,
                        help='location of the evaluation data corpus')
    parser.add_argument('--tokenizer_path', type=str, required=True, help='location of the tokenizer')
    parser.add_argument('--max_seq_length', type=int, default=512, help='sequence length')
    parser.add_argument('--refresh_bucket_size',
                        type=int,
                        default=1,
                        help="This param makes sure that a certain task is repeated for this many time steps to \
                        optimise the back propagation speed with APEX's DistributedDataParallel")
    parser.add_argument("--max_predictions_per_seq",
                        "--max_pred",
                        default=80,
                        type=int,
                        help="The maximum number of masked tokens in a sequence to be predicted.")
    parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="accumulation steps")
    parser.add_argument("--train_micro_batch_size_per_gpu", default=2, type=int, required=True, help="train batch size")
    parser.add_argument("--eval_micro_batch_size_per_gpu", default=2, type=int, required=True, help="eval batch size")
    parser.add_argument("--num_workers", default=8, type=int, help="number of dataloader workers")
    parser.add_argument("--async_worker", action='store_true', help="")
    parser.add_argument("--bert_config", required=True, type=str, help="location of config.json")
    parser.add_argument("--wandb", action='store_true', help="use wandb to watch model")
    parser.add_argument("--wandb_project_name", default='roberta', help="wandb project name")
    parser.add_argument("--log_interval", default=100, type=int, help="report interval")
    parser.add_argument("--log_path", type=str, required=True, help="log file which records train step")
    parser.add_argument("--tensorboard_path", type=str, required=True, help="location of tensorboard file")
    parser.add_argument("--colossal_config",
                        type=str,
                        required=True,
                        help="colossal config, which contains zero config and so on")
    parser.add_argument("--ckpt_path",
                        type=str,
                        required=True,
                        help="location of saving checkpoint, which contains model and optimizer")
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
    parser.add_argument('--vscode_debug', action='store_true', help="use vscode to debug")
    parser.add_argument('--load_pretrain_model', default='', type=str, help="location of the model's checkpoint")
    parser.add_argument(
        '--load_optimizer_lr',
        default='',
        type=str,
        help="location of checkpoint, which contains optimizer, learning rate, epoch, shard and global_step")
    parser.add_argument('--resume_train', action='store_true', help="whether to resume training from an earlier checkpoint")
    parser.add_argument('--mlm', default='bert', type=str, help="model type, bert or deberta")
    parser.add_argument('--checkpoint_activations', action='store_true', help="whether to use gradient checkpointing")

    args = parser.parse_args()
    return args
class BertDatasetProviderInterface:

    def get_shard(self, index, shuffle=True):
        raise NotImplementedError
...
import math
import os

import torch
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
from tqdm import tqdm
from utils.global_vars import get_tensorboard_writer, get_timers

def evaluate(model, args, logger, global_step, criterion):
    evaluate_dataset_provider = NvidiaBertDatasetProvider(args, evaluate=True)
@@ -20,16 +22,19 @@ def evaluate(model, args, logger, global_step, criterion):
    for shard in range(start_shard, len(os.listdir(args.eval_data_path_prefix))):

        timers('eval_shard_time').start()

        dataset_iterator, total_length = evaluate_dataset_provider.get_shard(shard)
        # evaluate_dataset_provider.prefetch_shard(shard + 1)
        if torch.distributed.get_rank() == 0:
            iterator_data = tqdm(enumerate(dataset_iterator),
                                 total=(total_length // args.eval_micro_batch_size_per_gpu // world_size),
                                 colour='MAGENTA',
                                 smoothing=1)
        else:
            iterator_data = enumerate(dataset_iterator)

        for step, batch_data in iterator_data:    # tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1):
            # batch_data = pretrain_dataset_provider.get_batch(batch_index)
            eval_step += 1
@@ -40,8 +45,8 @@ def evaluate(model, args, logger, global_step, criterion):
            # nsp_label = batch_data[5].cuda()

            output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
            loss = criterion(output.logits, mlm_label)    # prediction_scores
            evaluate_dataset_provider.prefetch_batch()

            eval_loss += loss.float().item()
...@@ -54,10 +59,10 @@ def evaluate(model, args, logger, global_step, criterion): ...@@ -54,10 +59,10 @@ def evaluate(model, args, logger, global_step, criterion):
if args.wandb and torch.distributed.get_rank() == 0: if args.wandb and torch.distributed.get_rank() == 0:
tensorboard_log = get_tensorboard_writer() tensorboard_log = get_tensorboard_writer()
tensorboard_log.log_eval({ tensorboard_log.log_eval({
'loss': cur_loss, 'loss': cur_loss,
'ppl': ppl, 'ppl': ppl,
'mins_batch': elapsed_time_per_iteration 'mins_batch': elapsed_time_per_iteration
}, global_step) }, global_step)
eval_log_str = f'evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ eval_log_str = f'evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \
f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}' f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}'
...@@ -68,4 +73,4 @@ def evaluate(model, args, logger, global_step, criterion): ...@@ -68,4 +73,4 @@ def evaluate(model, args, logger, global_step, criterion):
evaluate_dataset_provider.release_shard() evaluate_dataset_provider.release_shard()
model.train() model.train()
return cur_loss return cur_loss
\ No newline at end of file
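
The log line above reports both `loss` and `ppl`. As a reminder of how the two relate, here is a small sketch of turning the summed per-step cross-entropy into a globally averaged loss and a perplexity; the exact reduction lives in the elided hunks, so treat the variable names and the all-reduce placement as assumptions.

import math
import torch
import torch.distributed as dist

def loss_to_ppl(eval_loss_sum, eval_step):
    """Average the summed per-step loss across steps and ranks, then exponentiate."""
    # Assumes the process group is initialized and CUDA is available.
    cur_loss = torch.tensor(eval_loss_sum / max(eval_step, 1)).cuda()
    dist.all_reduce(cur_loss)             # sum the partial means from every rank
    cur_loss /= dist.get_world_size()     # global mean cross-entropy (nats/token)
    return cur_loss.item(), math.exp(cur_loss.item())    # perplexity = exp(mean CE)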
...@@ -13,5 +13,5 @@ class LossForPretraining(torch.nn.Module):
    def forward(self, prediction_scores, masked_lm_labels, next_sentence_labels=None):
        masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1))
        # next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1))
        total_loss = masked_lm_loss    #+ next_sentence_loss
        return total_loss
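
The diff shows only `forward`, so `self.loss_fn` and `self.vocab_size` come from an elided constructor. A plausible shape for it, assuming a cross-entropy criterion that drops unmasked positions via an ignore index (the `-1` default is an assumption, not taken from this commit):

import torch

class LossForPretraining(torch.nn.Module):
    """Assumed constructor: MLM-only cross-entropy over the vocabulary."""

    def __init__(self, vocab_size, ignore_index=-1):
        super().__init__()
        # Positions that were not masked carry the ignore index as their label
        # and therefore contribute nothing to the averaged loss.
        self.vocab_size = vocab_size
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)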
...@@ -15,7 +15,6 @@
# limitations under the License.
"""PyTorch BERT model."""

import math
import os
import warnings
...@@ -27,7 +26,6 @@ import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
...@@ -41,8 +39,9 @@ from transformers.modeling_outputs import (
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.bert.configuration_bert import BertConfig
from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from transformers.utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
...@@ -50,8 +49,6 @@ from transformers.utils import (
    logging,
    replace_return_docstrings,
)

logger = logging.get_logger(__name__)
...@@ -62,8 +59,7 @@ _TOKENIZER_FOR_DOC = "BertTokenizer"
# TokenClassification docstring
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
_TOKEN_CLASS_EXPECTED_OUTPUT = (
    "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] ")
_TOKEN_CLASS_EXPECTED_LOSS = 0.01

# QuestionAnswering docstring
...@@ -78,7 +74,6 @@ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-pol
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
_SEQ_CLASS_EXPECTED_LOSS = 0.01

BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "bert-base-uncased",
    "bert-large-uncased",
...@@ -114,10 +109,8 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
                     "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
...@@ -135,10 +128,8 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
               for n in name):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        pointer = model
...@@ -218,7 +209,7 @@ class BertEmbeddings(nn.Module):
        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length:seq_length + past_key_values_length]

        # Setting the token_type_ids to the registered buffer in the constructor, where it is all zeros, which usually occurs
        # when it's auto-generated; the registered buffer helps users when tracing the model without passing token_type_ids, and solves
...@@ -245,13 +236,12 @@ class BertEmbeddings(nn.Module):
class BertSelfAttention(nn.Module):

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                             f"heads ({config.num_attention_heads})")

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
...@@ -262,9 +252,7 @@ class BertSelfAttention(nn.Module):
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
...@@ -332,7 +320,7 @@ class BertSelfAttention(nn.Module):
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)    # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
...@@ -372,6 +360,7 @@ class BertSelfAttention(nn.Module):
class BertSelfOutput(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
...@@ -386,6 +375,7 @@ class BertSelfOutput(nn.Module):
class BertAttention(nn.Module):

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
...@@ -395,9 +385,8 @@ class BertAttention(nn.Module):
    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.self.num_attention_heads,
                                                        self.self.attention_head_size, self.pruned_heads)

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
...@@ -430,11 +419,12 @@ class BertAttention(nn.Module):
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]    # add attentions if we output them
        return outputs


class BertIntermediate(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
...@@ -450,6 +440,7 @@ class BertIntermediate(nn.Module):
class BertOutput(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
...@@ -464,6 +455,7 @@ class BertOutput(nn.Module):
class BertLayer(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
...@@ -504,15 +496,14 @@ class BertLayer(nn.Module):
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]    # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`")

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
...@@ -526,15 +517,14 @@ class BertLayer(nn.Module):
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]    # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        layer_output = apply_chunking_to_forward(self.feed_forward_chunk, self.chunk_size_feed_forward,
                                                 self.seq_len_dim, attention_output)
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
...@@ -550,6 +540,7 @@ class BertLayer(nn.Module):
class BertEncoder(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
...@@ -585,11 +576,11 @@ class BertEncoder(nn.Module):
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                    use_cache = False

                def create_custom_forward(module):

                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)
...@@ -626,17 +617,13 @@ class BertEncoder(nn.Module):
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [
                hidden_states,
                next_decoder_cache,
                all_hidden_states,
                all_self_attentions,
                all_cross_attentions,
            ] if v is not None)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
...@@ -647,6 +634,7 @@ class BertEncoder(nn.Module):
class BertPooler(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
...@@ -662,6 +650,7 @@ class BertPooler(nn.Module):
class BertPredictionHeadTransform(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
...@@ -679,6 +668,7 @@ class BertPredictionHeadTransform(nn.Module):
class BertLMPredictionHead(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
...@@ -699,6 +689,7 @@ class BertLMPredictionHead(nn.Module):
class BertOnlyMLMHead(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
...@@ -709,6 +700,7 @@ class BertOnlyMLMHead(nn.Module):
class BertOnlyNSPHead(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)
...@@ -719,6 +711,7 @@ class BertOnlyNSPHead(nn.Module):
class BertPreTrainingHeads(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
...@@ -950,9 +943,8 @@ class BertModel(BertPreTrainedModel):
            `past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
...@@ -1051,6 +1043,7 @@ class BertModel(BertPreTrainedModel):
    BERT_START_DOCSTRING,
)
class BertForPreTraining(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
...@@ -1151,9 +1144,8 @@ class BertForPreTraining(BertPreTrainedModel):
        )


@add_start_docstrings("""Bert Model with a `language modeling` head on top for CLM fine-tuning.""",
                      BERT_START_DOCSTRING)
class BertLMHeadModel(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
...@@ -1298,10 +1290,8 @@ class BertForMaskedLM(BertPreTrainedModel):
        super().__init__(config)

        if config.is_decoder:
            logger.warning("If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
                           "bi-directional self-attention.")

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)
...@@ -1367,7 +1357,7 @@ class BertForMaskedLM(BertPreTrainedModel):
        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()    # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
...@@ -1390,9 +1380,10 @@ class BertForMaskedLM(BertPreTrainedModel):
            raise ValueError("The PAD token should be defined for generation")

        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        dummy_token = torch.full((effective_batch_size, 1),
                                 self.config.pad_token_id,
                                 dtype=torch.long,
                                 device=input_ids.device)
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}
...@@ -1403,6 +1394,7 @@ class BertForMaskedLM(BertPreTrainedModel):
    BERT_START_DOCSTRING,
)
class BertForNextSentencePrediction(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
...@@ -1508,15 +1500,15 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
    BERT_START_DOCSTRING,
)
class BertForSequenceClassification(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        classifier_dropout = (config.classifier_dropout
                              if config.classifier_dropout is not None else config.hidden_dropout_prob)
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
...@@ -1612,13 +1604,13 @@ class BertForSequenceClassification(BertPreTrainedModel):
    BERT_START_DOCSTRING,
)
class BertForMultipleChoice(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        classifier_dropout = (config.classifier_dropout
                              if config.classifier_dropout is not None else config.hidden_dropout_prob)
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, 1)
...@@ -1658,11 +1650,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
                         if inputs_embeds is not None else None)

        outputs = self.bert(
            input_ids,
...@@ -1715,9 +1704,8 @@ class BertForTokenClassification(BertPreTrainedModel):
        self.num_labels = config.num_labels

        self.bert = BertModel(config, add_pooling_layer=False)
        classifier_dropout = (config.classifier_dropout
                              if config.classifier_dropout is not None else config.hidden_dropout_prob)
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
...
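
One pattern in the reformatted `BertEncoder` deserves a note: `torch.utils.checkpoint.checkpoint` only replays the positional tensor inputs it receives, so non-tensor arguments such as `past_key_value` and `output_attentions` are captured in a `create_custom_forward` closure instead. A self-contained sketch of the same trick with a toy module (the `TinyBlock` module is an illustrative stand-in, not the BERT layer itself):

import torch
import torch.utils.checkpoint

class TinyBlock(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x, scale, verbose):
        if verbose:
            print("running block")
        return self.linear(x) * scale

block = TinyBlock()

def create_custom_forward(module):
    # Capture the non-tensor arguments in the closure; checkpoint() only
    # forwards the positional tensors it is given.
    def custom_forward(*inputs):
        return module(*inputs, 0.5, False)
    return custom_forward

x = torch.randn(4, 8, requires_grad=True)
out = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
out.sum().backward()    # activations inside the block are recomputed here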
import json
import logging
import os
import random
import time
from concurrent.futures import ProcessPoolExecutor

import h5py
import numpy as np
import torch
import torch.distributed as dist
from bert_dataset_provider import BertDatasetProviderInterface
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler

import colossalai.utils as utils

# Workaround because python functions are not picklable
class WorkerInitObj(object):

    def __init__(self, seed):
        self.seed = seed
...@@ -27,29 +28,25 @@ class WorkerInitObj(object):
        random.seed(self.seed + id)


def create_pretraining_dataset(input_file, max_predictions_per_seq, num_workers, train_batch_size, worker_init,
                               data_sampler):
    train_data = pretraining_dataset(input_file=input_file, max_predictions_per_seq=max_predictions_per_seq)
    train_dataloader = DataLoader(train_data,
                                  sampler=data_sampler(train_data),
                                  batch_size=train_batch_size,
                                  num_workers=num_workers,
                                  worker_init_fn=worker_init,
                                  pin_memory=True)
    return train_dataloader, len(train_data)


class pretraining_dataset(Dataset):

    def __init__(self, input_file, max_predictions_per_seq):
        self.input_file = input_file
        self.max_predictions_per_seq = max_predictions_per_seq
        f = h5py.File(input_file, "r")
        keys = ['input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions']
        self.inputs = [np.asarray(f[key][:]) for key in keys]
        f.close()
...@@ -59,21 +56,16 @@ class pretraining_dataset(Dataset):
    def __getitem__(self, index):
        [input_ids, input_mask, segment_ids, masked_lm_labels] = [
            torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else torch.from_numpy(
                np.asarray(input[index].astype(np.int64))) for indice, input in enumerate(self.inputs)
        ]

        return [input_ids, input_mask, segment_ids, masked_lm_labels]

class NvidiaBertDatasetProvider(BertDatasetProviderInterface):

    def __init__(self, args, evaluate=False):
        self.num_workers = args.num_workers
        self.max_seq_length = args.max_seq_length
...@@ -85,22 +77,24 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
        else:
            self.train_micro_batch_size_per_gpu = args.eval_micro_batch_size_per_gpu
        self.logger = args.logger

        self.global_rank = dist.get_rank()
        self.world_size = dist.get_world_size()

        # Initialize dataset files
        if not evaluate:
            self.dataset_files = [
                os.path.join(args.data_path_prefix, f)
                for f in os.listdir(args.data_path_prefix)
                if os.path.isfile(os.path.join(args.data_path_prefix, f)) and 'h5' in f
            ]
        else:
            self.dataset_files = [
                os.path.join(args.eval_data_path_prefix, f)
                for f in os.listdir(args.eval_data_path_prefix)
                if os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and 'h5' in f
            ]

        self.dataset_files.sort()
        # random.shuffle(self.dataset_files)
        self.num_files = len(self.dataset_files)
...@@ -114,9 +108,7 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
        self.shuffle = True

        if self.global_rank == 0:
            self.logger.info(f"NvidiaBertDatasetProvider - Initialization: num_files = {self.num_files}")

    def get_shard(self, index):
        start = time.time()
...@@ -130,9 +122,8 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
                                                                               worker_init=self.worker_init,
                                                                               data_sampler=self.data_sampler)
        else:
            self.train_dataloader, sample_count = self.dataset_future.result(timeout=None)

        self.logger.info(
            f"Data Loading Completed for Pretraining Data from {self.data_file} with {sample_count} samples took {time.time()-start:.2f}s."
        )
...@@ -145,11 +136,9 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
    def prefetch_shard(self, index):
        self.data_file = self._get_shard_file(index)
        self.dataset_future = self.pool.submit(create_pretraining_dataset, self.data_file, self.max_predictions_per_seq,
                                               self.num_workers, self.train_micro_batch_size_per_gpu, self.worker_init,
                                               self.data_sampler)

    def get_batch(self, batch_iter):
        return batch_iter
...@@ -179,4 +168,3 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
        indices = torch.randperm(self.num_files, generator=g).tolist()
        new_dataset = [self.dataset_files[i] for i in indices]
        self.dataset_files = new_dataset
\ No newline at end of file
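
Putting the provider together, here is a sketch of the intended shard pipeline: serve shard `i`, submit shard `i + 1` to the process pool so I/O overlaps with compute, then release the shard. The `args` object and the training-step body are assumptions for the sketch, not code from this commit.

# Illustrative driver loop for the provider (field names follow the code above).
provider = NvidiaBertDatasetProvider(args)
for shard in range(provider.num_files):
    dataset_iterator, total_length = provider.get_shard(shard)
    if shard + 1 < provider.num_files:
        provider.prefetch_shard(shard + 1)    # background load overlaps training
    for step, batch in enumerate(dataset_iterator):
        ...    # forward/backward on the micro-batch
    provider.release_shard()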