Commit e532679c authored by oahzxl's avatar oahzxl

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

parents c1492e50 7d5640b9
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
# helper function
def exists(val):
return val is not None
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
# top k filtering
def top_k(logits, thres=0.9):
k = int((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float("-inf"))
probs.scatter_(1, ind, val)
return probs
class AutoregressiveWrapper(nn.Module):
def __init__(self, net, max_seq_len=2048, pad_value=0):
super().__init__()
self.max_seq_len = max_seq_len
self.pad_value = pad_value
self.net = net
@torch.no_grad()
@eval_decorator
def generate(self, start_tokens, seq_len, eos_token=None, temperature=1.0, filter_thres=0.9, **kwargs):
b, t, device = *start_tokens.shape, start_tokens.device
out = start_tokens
for _ in range(seq_len):
logits = self.net(out, **kwargs)[:, -1, :]
filtered_logits = top_k(logits, thres=filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)
sample = torch.multinomial(probs, 1)
out = torch.cat((out, sample), dim=-1)
if exists(eos_token):
is_eos_token = out == eos_token
if is_eos_token.any(dim=-1).all():
# mask out everything after the eos tokens
shifted_is_eos_tokens = F.pad(is_eos_token, (1, -1))
mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
out = out.masked_fill(mask, self.pad_value)
break
out = out[:, t:]
return out
def forward(self, x, **kwargs):
x_inp, x_labels = x[:, :-1], x[:, 1:]
logits = self.net(x_inp, **kwargs)
return F.cross_entropy(rearrange(logits, "b n c -> b c n"), x_labels)
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum, matmul, nn
# normalization
# they use layernorm without bias, something that pytorch does not offer
class LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer("beta", torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
# parallel with residual
# discovered by Wang et al + EleutherAI from GPT-J fame
class ParallelResidual(nn.Module):
def __init__(self, *fns):
super().__init__()
self.fns = nn.ModuleList(fns)
def forward(self, x):
return x + sum([fn(x) for fn in self.fns])
# rotary positional embedding
# https://arxiv.org/abs/2104.09864
class RotaryEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
def forward(self, max_seq_len, *, device):
seq = torch.arange(max_seq_len, device=device)
#freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
#freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
i, j = len(seq.type_as(self.inv_freq)), len(self.inv_freq)
freqs = matmul(seq.type_as(self.inv_freq).reshape(i, 1), self.inv_freq.reshape(1, j))
return torch.cat((freqs, freqs), dim=-1)
def rotate_half(x):
x = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(pos, t):
return (t * pos.cos()) + (rotate_half(t) * pos.sin())
# feedforward
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU
# https://arxiv.org/abs/2002.05202
class SwiGLU(nn.Module):
def forward(self, x):
x, gate = x.chunk(2, dim=-1)
return F.silu(gate) * x
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
LayerNorm(dim),
nn.Linear(dim, inner_dim * 2, bias=False),
SwiGLU(),
nn.Linear(inner_dim, dim, bias=False),
)
# attention
class Attention(nn.Module):
def __init__(self, dim, dim_head=64, heads=8):
super().__init__()
inner_dim = dim_head * heads
self.norm = LayerNorm(dim)
self.heads = heads
self.scale = dim_head**-0.5
self.rotary_emb = RotaryEmbedding(dim_head)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, dim_head * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
# for caching causal mask and rotary embeddings
self.register_buffer("mask", None, persistent=False)
self.register_buffer("pos_emb", None, persistent=False)
def get_mask(self, n, device):
if self.mask is not None and self.mask.shape[-1] >= n:
return self.mask[:n, :n]
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
self.register_buffer("mask", mask, persistent=False)
return mask
def get_rotary_embedding(self, n, device):
if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
return self.pos_emb[:n]
pos_emb = self.rotary_emb(n, device=device)
self.register_buffer("position", pos_emb, persistent=False)
return pos_emb
def forward(self, x):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""
n, device, h = x.shape[1], x.device, self.heads
# pre layernorm
x = self.norm(x)
# queries, keys, values
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
# split heads
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
# they found no performance loss past a certain scale, and more efficient decoding obviously
# https://arxiv.org/abs/1911.02150
q = rearrange(q, "b n (h d) -> b h n d", h=h)
# rotary embeddings
positions = self.get_rotary_embedding(n, device)
q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
# scale
q = q * self.scale
b, h, i, d, j = q.size(0), q.size(1), q.size(2), q.size(3), k.size(1)
# similarity
#sim = einsum("b h i d, b j d -> b h i j", q, k)
sim = matmul(q.reshape(b, h * i, d), k.transpose(1, 2))
sim = sim.reshape(b, h, i, j)
# causal mask
causal_mask = self.get_mask(n, device)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
# attention
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
b_, h_, i_, j_, d_ = attn.size(0), attn.size(1), attn.size(2), attn.size(3), v.size(2)
# aggregate values
#out = einsum("b h i j, b j d -> b h i d", attn, v)
out = matmul(attn.reshape(b_, h_ * i_, j_), v)
out = out.reshape(b_, h_, i_, d_)
# merge heads
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
# transformer
def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
net = nn.Sequential(
nn.Embedding(num_tokens, dim), *[
ParallelResidual(
Attention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
) for _ in range(depth)
], LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False))
# they used embedding weight tied projection out to logits, not common, but works
net[-1].weight = net[0].weight
nn.init.normal_(net[0].weight, std=0.02)
return net
colossalai >= 0.1.12
torch >= 1.8.1
# distplan in ["colossalai", "pytorch"]
export DISTPAN="colossalai"
# The following options are only valid when DISTPAN="colossalai"
export TPDEGREE=1
export GPUNUM=1
export PLACEMENT='cpu'
export USE_SHARD_INIT=False
export BATCH_SIZE=4
env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train_new.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
import gzip
import random
import numpy as np
import torch
import torch.optim as optim
import tqdm
from packaging import version
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import GeminiDDP, ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
# constants
NUM_BATCHES = int(1000)
GRADIENT_ACCUMULATE_EVERY = 1
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 1024
def parse_args():
parser = colossalai.get_default_parser()
parser.add_argument(
"--distplan",
type=str,
default='colossalai',
help="The distributed plan [colossalai, pytorch].",
)
parser.add_argument(
"--tp_degree",
type=int,
default=1,
help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--placement",
type=str,
default='cpu',
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--shardinit",
type=bool,
default=False,
help=
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--batch_size",
type=int,
default=8,
help="batch size per DP group of training.",
)
args = parser.parse_args()
return args
# helpers
def cycle(loader):
while True:
for data in loader:
yield data
def decode_token(token):
return str(chr(max(32, token)))
def decode_tokens(tokens):
return "".join(list(map(decode_token, tokens)))
# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
cai_version = colossalai.__version__
if version.parse(cai_version) > version.parse("0.1.10"):
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placement_policy,
pin_memory=True,
search_range_mb=32)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=True,
init_device=GeminiManager.get_default_device(placement_policy))
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)
else:
raise NotImplemented(f"CAI version {cai_version} is not supported")
return model
## Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
param.set_tensor_spec(*spec)
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(0, param, pg)
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(-1, param, pg)
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
"""tensor_parallelize
Sharding the Model Parameters.
Args:
model (torch.nn.Module): a torch module to be sharded
"""
for mn, module in model.named_modules():
for pn, param in module.named_parameters(recurse=False):
if hasattr(param, 'visited'):
continue
param.set_dist_spec(ReplicaSpec())
if 'net.0' in mn:
split_param_col_tp1d(param, pg) # column slice
elif 'to_q' in mn:
split_param_col_tp1d(param, pg) # column slice
elif 'to_kv' in mn:
split_param_row_tp1d(param, pg) # row slice
elif 'to_out' in mn:
split_param_row_tp1d(param, pg) # row slice
elif '1.1' in mn:
split_param_col_tp1d(param, pg) # column slice
elif '1.2' in mn:
split_param_row_tp1d(param, pg) # row slice
else:
param.set_dist_spec(ReplicaSpec())
param.visited = True
args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
raise TypeError(f"{args.distplan} is error")
disable_existing_loggers()
colossalai.launch_from_torch(config={})
with gzip.open("./data/enwik8.gz") as file:
X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
trX, vaX = np.split(X, [int(90e6)])
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
class TextSamplerDataset(Dataset):
def __init__(self, data, seq_len):
super().__init__()
self.data = data
self.seq_len = seq_len
def __getitem__(self, index):
rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long()
return full_seq.cuda()
def __len__(self):
return self.data.size(0) // self.seq_len
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=args.batch_size))
val_loader = cycle(DataLoader(val_dataset, batch_size=args.batch_size))
if args.distplan == "colossalai":
# instantiate GPT-like decoder model
default_pg = ProcessGroup(tp_degree=args.tp_degree)
default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
with ctx:
model = PaLM(num_tokens=256, dim=512, depth=8)
model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
pg = default_pg
tensor_parallelize(model, pg)
model = gemini_zero_dpp(model, pg, args.placement)
#optimizer
#optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5)
optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5)
else:
model = PaLM(num_tokens=256, dim=512, depth=8)
model = AutoregressiveWrapper(model, max_seq_len=2048)
model.cuda()
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# training
model.train()
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
if args.distplan == "colossalai":
optimizer.zero_grad()
loss = model(next(train_loader))
# loss.backward()
optimizer.backward(loss)
print(f"training loss: {loss.item()}")
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
# optim.step()
# optim.zero_grad()
optimizer.step()
else:
for __ in range(GRADIENT_ACCUMULATE_EVERY):
loss = model(next(train_loader))
loss.backward()
print(f"training loss: {loss.item()}")
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()
# TODO
# if i % VALIDATE_EVERY == 0:
# model.eval()
# with torch.no_grad():
# loss = model(next(val_loader))
# print(f"validation loss: {loss.item()}")
# if i % GENERATE_EVERY == 0:
# model.eval()
# inp = random.choice(val_dataset)[:-1]
# prime = decode_tokens(inp)
# print(f"%s \n\n %s", (prime, "*" * 100))
# sample = model.generate(inp[None, ...], GENERATE_LENGTH)
# output_str = decode_tokens(sample[0])
# print(output_str)
# Introduction
This repo shows how to pretrain a Chinese RoBERTa-Large from scratch, covering preprocessing, pretraining, and finetuning. It can help you quickly train a high-quality BERT-style model.
## 0. Prerequisite
- Install Colossal-AI
- Edit the port in /etc/ssh/sshd_config and /etc/ssh/ssh_config so that every host exposes the same SSH port for both server and client. If you are a root user, also set **PermitRootLogin** in /etc/ssh/sshd_config to "yes"
- Ensure that each host can log in to every other host without a password. If you have n hosts, this needs to be executed n<sup>2</sup> times:
```
ssh-keygen
ssh-copy-id -i ~/.ssh/id_rsa.pub ip_destination
```
- On every host, edit /etc/hosts to record all hosts' names and IP addresses. An example is shown below.
```bash
192.168.2.1 GPU001
192.168.2.2 GPU002
192.168.2.3 GPU003
192.168.2.4 GPU004
192.168.2.5 GPU005
192.168.2.6 GPU006
192.168.2.7 GPU007
...
```
- Restart the SSH service:
```
service ssh restart
```
## 1. Corpus Preprocessing
```bash
cd preprocessing
```
Following the `README.md` in that folder, preprocess the original corpus into h5py + numpy format.
## 2. Pretrain
```bash
cd pretraining
```
Following the `README.md` in that folder, load the h5py files generated in step 1 to pretrain the model.
## 3. Finetune
The checkpoint produced by this repo can directly replace `pytorch_model.bin` from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main). Then use the Hugging Face `transformers` library to finetune downstream applications, as sketched below.
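For example, a minimal finetuning sketch with `transformers` could look like the following (the local directory name and the binary classification task are illustrative assumptions):

```python
from transformers import AutoTokenizer, BertForSequenceClassification

# Hypothetical local directory: download config.json, vocab.txt, tokenizer.json, etc. from
# hfl/chinese-roberta-wwm-ext-large, then replace pytorch_model.bin with the checkpoint
# produced by this repo.
model_dir = "./chinese-roberta-large"

tokenizer = AutoTokenizer.from_pretrained(model_dir)
# The classification head is randomly initialized and learned during finetuning.
model = BertForSequenceClassification.from_pretrained(model_dir, num_labels=2)

inputs = tokenizer("我今天去打篮球。", return_tensors="pt")
logits = model(**inputs).logits
print(logits.shape)    # torch.Size([1, 2])
```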
## Contributors
The repo is contributed by the AI team from [Moore Threads](https://www.mthreads.com/). If you find any problem with pretraining, please file an issue or send an email to yehua.zhang@mthreads.com. Finally, any form of contribution is welcome!
```
@misc{
title={A simple Chinese RoBERTa Example for Whole Word Masked},
author={Yehua Zhang, Chen Zhang},
year={2022}
}
```
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.nn.optimizer import FusedAdam
clip_grad_norm = 1.0
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.nn.optimizer import FusedAdam
# fp16 = dict(
# mode=AMP_TYPE.TORCH,
# )
# seed = 2
zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(),
reduce_scatter_bucket_size_mb=25,
fp32_reduce_scatter=False,
tensor_placement_policy="cuda",
gradient_predivide_factor=1.0,
reuse_fp16_shard=False),
optimizer_config=dict(gpu_margin_mem_ratio=0.8,
initial_scale=2**5,
min_scale=1,
growth_factor=2,
backoff_factor=0.5,
growth_interval=1000,
hysteresis=2,
max_scale=2**32))
# gradient_accumulation = 4
clip_grad_norm = 1.0
optimizer = dict(
type=FusedAdam,
lr=0.00015,
weight_decay=1e-2,
)
# 64433
CXXFLAGS += -O3 -Wall -shared -std=c++14 -fPIC -fdiagnostics-color
CPPFLAGS += $(shell python3 -m pybind11 --includes)
LIBNAME = mask
LIBEXT = $(shell python3-config --extension-suffix)
default: $(LIBNAME)$(LIBEXT)
%$(LIBEXT): %.cpp
$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@
# Data Preprocessing for Chinese Whole Word Masking
<span id='all_catelogue'/>
## Catalogue:
* <a href='#introduction'>1. Introduction</a>
* <a href='#Quick Start Guide'>2. Quick Start Guide:</a>
* <a href='#Split Sentence'>2.1. Split Sentence</a>
* <a href='#Tokenizer & Whole Word Masked'>2.2. Tokenizer & Whole Word Masking</a>
<span id='introduction'/>
## 1. Introduction: <a href='#all_catelogue'>[Back to Top]</a>
This folder is used to preprocess a Chinese corpus with Whole Word Masking. You can obtain a corpus from [WuDao](https://resource.wudaoai.cn/home?ind&name=WuDaoCorpora%202.0&id=1394901288847716352). The data preprocessing is flexible: you can modify the code based on your needs, hardware, or parallel framework (Open MPI, Spark, Dask).
<span id='Quick Start Guide'/>
## 2. Quick Start Guide: <a href='#all_catelogue'>[Back to Top]</a>
<span id='Split Sentence'/>
### 2.1. Split Sentence & Split Data into Multiple Shards:
Firstly, each file has multiple documents, and each document contains multiple sentences. Split sentences on punctuation such as `。!`. **Secondly, split the data into multiple shards based on the server hardware (CPU, CPU memory, hard disk) and the corpus size.** Each shard contains a part of the corpus, and the model must train on all the shards to complete one epoch.
In this example, a 200 GB corpus is split into 100 shards, so each shard is about 2 GB. The shard size is memory-dependent: it takes into account the number of servers, the memory used by the tokenizer, and the memory used by the multi-process training that reads the shards (a data-parallel degree of n requires roughly n\*shard_size memory; for example, 8 data-parallel workers reading 2 GB shards need about 16 GB of host memory). **To sum up, data preprocessing and model pretraining mean fighting with the hardware, not just the GPU.**
```bash
python sentence_split.py --input_path /orginal_corpus --output_path /shard --shard 100
# This step takes a short time
```
* `--input_path`: location of the original corpus, e.g., /orginal_corpus/0.json /orginal_corpus/1.json ...
* `--output_path`: location of all shards with split sentences, e.g., /shard/0.txt, /shard/1.txt ...
* `--shard`: number of shards, e.g., 10, 50, or 100
**Input json:**
```
[
{
"id": 0,
"title": "打篮球",
"content": "我今天去打篮球。不回来吃饭。"
},
{
"id": 1,
"title": "旅游",
"content": "我后天去旅游。下周请假。"
}
]
```
**Output txt:**
```
我今天去打篮球。
不回来吃饭。
]]
我后天去旅游。
下周请假。
```
<span id='Tokenizer & Whole Word Masked'/>
### 2.2. Tokenizer & Whole Word Masking:
```bash
python tokenize_mask.py --input_path /shard --output_path /h5 --tokenizer_path /roberta --backend python
# This step is time-consuming; most of the time is spent on masking
```
**[Optional but recommended]**: the C++ backend built with `pybind11` provides a faster preprocessing speed
```shell
make
```
* `--input_path`: location of all shards with split sentences, e.g., /shard/0.txt, /shard/1.txt ...
* `--output_path`: location of all h5 files with token_id, input_mask, segment_ids and masked_lm_positions, e.g., /h5/0.h5, /h5/1.h5 ...
* `--tokenizer_path`: tokenizer path containing the Hugging Face tokenizer.json. Download config.json, special_tokens_map.json, vocab.txt and tokenizer.json from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main)
* `--backend`: python or c++; **c++ gives a faster preprocessing speed**
* `--dupe_factor`: how many times the preprocessor repeats creating inputs from the same article/document
* `--worker`: number of worker processes
**Input txt:**
```
我今天去打篮球。
不回来吃饭。
]]
我后天去旅游。
下周请假。
```
**Output h5+numpy:**
```
'input_ids': [[id0,id1,id2,id3,id4,id5,id6,0,0..],
...]
'input_mask': [[1,1,1,1,1,1,0,0..],
...]
'segment_ids': [[0,0,0,0,0,...],
...]
'masked_lm_positions': [[label1,-1,-1,label2,-1...],
...]
```
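As a quick sanity check, a generated shard can be inspected with `h5py` (a minimal sketch; the file path is an assumption):

```python
import h5py

# hypothetical path to one preprocessed shard
with h5py.File("/h5/0.h5", "r") as hf:
    input_ids = hf["input_ids"][:]                        # token ids, padded with 0
    input_mask = hf["input_mask"][:]                      # 1 for real tokens, 0 for padding
    segment_ids = hf["segment_ids"][:]                    # all zeros for single-segment MLM
    masked_lm_positions = hf["masked_lm_positions"][:]    # label ids at masked positions, -1 elsewhere

print(input_ids.shape)    # (num_sequences, seq_len)
```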
import torch
import os
from enum import IntEnum
from random import choice
import random
import collections
import time
import logging
import jieba
jieba.setLogLevel(logging.CRITICAL)
import re
import numpy as np
import mask
PAD = 0
MaskedLMInstance = collections.namedtuple("MaskedLMInstance",
["index", "label"])
def map_to_numpy(data):
return np.asarray(data)
class PreTrainingDataset():
def __init__(self,
tokenizer,
max_seq_length,
backend='python',
max_predictions_per_seq: int = 80,
do_whole_word_mask: bool = True):
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
self.masked_lm_prob = 0.15
self.backend = backend
self.do_whole_word_mask = do_whole_word_mask
self.max_predictions_per_seq = max_predictions_per_seq
self.vocab_words = list(tokenizer.vocab.keys())
self.rec = re.compile('[\u4E00-\u9FA5]')
self.whole_rec = re.compile('##[\u4E00-\u9FA5]')
self.mlm_p = 0.15
self.mlm_mask_p = 0.8
self.mlm_tamper_p = 0.05
self.mlm_maintain_p = 0.1
def tokenize(self, doc):
temp = []
for d in doc:
temp.append(self.tokenizer.tokenize(d))
return temp
def create_training_instance(self, instance):
is_next = 1
raw_text_list = self.get_new_segment(instance)
tokens_a = raw_text_list
assert len(tokens_a) == len(instance)
# tokens_a, tokens_b, is_next = instance.get_values()
# print(f'is_next label:{is_next}')
# Create mapper
tokens = []
original_tokens = []
segment_ids = []
tokens.append("[CLS]")
original_tokens.append('[CLS]')
segment_ids.append(0)
for index, token in enumerate(tokens_a):
tokens.append(token)
original_tokens.append(instance[index])
segment_ids.append(0)
tokens.append("[SEP]")
original_tokens.append('[SEP]')
segment_ids.append(0)
# for token in tokens_b:
# tokens.append(token)
# segment_ids.append(1)
# tokens.append("[SEP]")
# segment_ids.append(1)
# Get Masked LM predictions
if self.backend == 'c++':
output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(tokens, original_tokens, self.vocab_words,
self.tokenizer.vocab, self.max_predictions_per_seq, self.masked_lm_prob)
elif self.backend == 'python':
output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens)
# Convert to Ids
input_ids = self.tokenizer.convert_tokens_to_ids(output_tokens)
input_mask = [1] * len(input_ids)
while len(input_ids) < self.max_seq_length:
input_ids.append(PAD)
segment_ids.append(PAD)
input_mask.append(PAD)
masked_lm_output.append(-1)
return ([
map_to_numpy(input_ids),
map_to_numpy(input_mask),
map_to_numpy(segment_ids),
map_to_numpy(masked_lm_output),
map_to_numpy([is_next])
])
def create_masked_lm_predictions(self, tokens):
cand_indexes = []
for i, token in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
if (self.do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
# cand_indexes.append(i)
random.shuffle(cand_indexes)
output_tokens = list(tokens)
num_to_predict = min(
self.max_predictions_per_seq,
max(1, int(round(len(tokens) * self.masked_lm_prob))))
masked_lms = []
covered_indexes = set()
for index in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
if index in covered_indexes:
continue
covered_indexes.add(index)
masked_token = None
# 80% mask
if random.random() < 0.8:
masked_token = "[MASK]"
else:
# 10% Keep Original
if random.random() < 0.5:
masked_token = tokens[index]
# 10% replace w/ random word
else:
masked_token = self.vocab_words[random.randint(
0,
len(self.vocab_words) - 1)]
output_tokens[index] = masked_token
masked_lms.append(
MaskedLMInstance(index=index, label=tokens[index]))
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_output = [-1] * len(output_tokens)
for p in masked_lms:
masked_lm_output[p.index] = self.tokenizer.vocab[p.label]
return (output_tokens, masked_lm_output)
def get_new_segment(self, segment):
"""
Take one sentence as input and return the processed sentence: to support Chinese whole word masking, the characters of a word that was split apart are marked with "##", so that downstream modules know which characters belong to the same word.
:param segment: one sentence
:return: the processed sentence
"""
seq_cws = jieba.lcut(''.join(segment))
seq_cws_dict = {x: 1 for x in seq_cws}
new_segment = []
i = 0
while i < len(segment):
if len(self.rec.findall(segment[i])) == 0: # not a Chinese character, keep it as-is
new_segment.append(segment[i])
i += 1
continue
has_add = False
for length in range(3, 0, -1):
if i + length > len(segment):
continue
if ''.join(segment[i: i+length]) in seq_cws_dict:
new_segment.append(segment[i])
for l in range(1, length):
new_segment.append('##' + segment[i+l])
i += length
has_add = True
break
if not has_add:
new_segment.append(segment[i])
i += 1
return new_segment
def create_whole_masked_lm_predictions(self, tokens):
"""Creates the predictions for the masked LM objective."""
cand_indexes = []
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
# Whole Word Masking means that if we mask all of the wordpieces
# corresponding to an original word. When a word has been split into
# WordPieces, the first token does not have any marker and any subsequence
# tokens are prefixed with ##. So whenever we see the ## token, we
# append it to the previous set of word indexes.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (self.do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
random.shuffle(cand_indexes)
output_tokens = [t[2:] if len(self.whole_rec.findall(t))>0 else t for t in tokens] # strip the leading "##"
num_to_predict = min(self.max_predictions_per_seq,
max(1, int(round(len(tokens) * self.masked_lm_prob))))
masked_lms = []
covered_indexes = set()
for index_set in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if random.random() < 0.8:
masked_token = "[MASK]"
else:
# 10% of the time, keep original
if random.random() < 0.5:
masked_token = tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index] # strip the leading "##"
# 10% of the time, replace with random word
else:
masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)]
output_tokens[index] = masked_token
masked_lms.append(MaskedLMInstance(index=index, label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index]))
assert len(masked_lms) <= num_to_predict
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_output = [-1] * len(output_tokens)
for p in masked_lms:
masked_lm_output[p.index] = self.tokenizer.vocab[p.label]
return (output_tokens, masked_lm_output)
#include <algorithm>
#include <iostream>
#include <limits>
#include <math.h>
#include <stdexcept>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <random>
#include <vector>
#include <string>
#include <pybind11/stl.h>
#include <chrono>
#include <tuple>
#include <unordered_set>
#include <unordered_map>
namespace py = pybind11;
const int32_t LONG_SENTENCE_LEN = 512;
struct MaskedLMInstance {
int index;
std::string label;
MaskedLMInstance(int index, std::string label) {
this->index = index;
this->label = label;
}
};
auto get_new_segment(std::vector<std::string> segment, std::vector<std::string> segment_jieba, const std::vector<bool> chinese_vocab) { // const std::unordered_set<std::string> &chinese_vocab
std::unordered_set<std::string> seq_cws_dict;
for (auto word : segment_jieba) {
seq_cws_dict.insert(word);
}
int i = 0;
std::vector<std::string> new_segment;
int segment_size = segment.size();
while (i < segment_size) {
if (!chinese_vocab[i]) { //chinese_vocab.find(segment[i]) == chinese_vocab.end()
new_segment.emplace_back(segment[i]);
i += 1;
continue;
}
bool has_add = false;
for (int length = 3; length >= 1; length--) {
if (i + length > segment_size) {
continue;
}
std::string chinese_word = "";
for (int j = i; j < i + length; j++) {
chinese_word += segment[j];
}
if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) {
new_segment.emplace_back(segment[i]);
for (int j = i + 1; j < i + length; j++) {
new_segment.emplace_back("##" + segment[j]);
}
i += length;
has_add = true;
break;
}
}
if (!has_add) {
new_segment.emplace_back(segment[i]);
i += 1;
}
}
return new_segment;
}
bool startsWith(const std::string& s, const std::string& sub) {
return s.find(sub) == 0 ? true : false;
}
auto create_whole_masked_lm_predictions(std::vector<std::string> &tokens,
const std::vector<std::string> &original_tokens,
const std::vector<std::string> &vocab_words,
std::map<std::string, int> &vocab,
const int max_predictions_per_seq,
const double masked_lm_prob) {
// for (auto item : vocab) {
// std::cout << "key=" << std::string(py::str(item.first)) << ", "
// << "value=" << std::string(py::str(item.second)) << std::endl;
// }
std::vector<std::vector<int> > cand_indexes;
std::vector<int> cand_temp;
int tokens_size = tokens.size();
std::string prefix = "##";
bool do_whole_masked = true;
for (int i = 0; i < tokens_size; i++) {
if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") {
continue;
}
if (do_whole_masked && (cand_indexes.size() > 0) && (tokens[i].rfind(prefix, 0) == 0)) {
cand_temp.emplace_back(i);
}
else {
if (cand_temp.size() > 0) {
cand_indexes.emplace_back(cand_temp);
}
cand_temp.clear();
cand_temp.emplace_back(i);
}
}
// flush the last word group collected in the loop above (matches the Python implementation)
if (cand_temp.size() > 0) {
cand_indexes.emplace_back(cand_temp);
}
auto seed = std::chrono::system_clock::now().time_since_epoch().count();
std::shuffle(cand_indexes.begin(), cand_indexes.end(), std::default_random_engine(seed));
// for (auto i : cand_indexes) {
// for (auto j : i) {
// std::cout << tokens[j] << " ";
// }
// std::cout << std::endl;
// }
// for (auto i : output_tokens) {
// std::cout << i;
// }
// std::cout << std::endl;
int num_to_predict = std::min(max_predictions_per_seq,
std::max(1, int(tokens_size * masked_lm_prob)));
// std::cout << num_to_predict << std::endl;
std::set<int> covered_indexes;
std::vector<int> masked_lm_output(tokens_size, -1);
int vocab_words_len = vocab_words.size();
std::default_random_engine e(seed);
std::uniform_real_distribution<double> u1(0.0, 1.0);
std::uniform_int_distribution<unsigned> u2(0, vocab_words_len - 1);
int mask_cnt = 0;
std::vector<std::string> output_tokens;
output_tokens = original_tokens;
for (auto index_set : cand_indexes) {
if (mask_cnt > num_to_predict) {
break;
}
int index_set_size = index_set.size();
if (mask_cnt + index_set_size > num_to_predict) {
continue;
}
bool is_any_index_covered = false;
for (auto index : index_set) {
if (covered_indexes.find(index) != covered_indexes.end()) {
is_any_index_covered = true;
break;
}
}
if (is_any_index_covered) {
continue;
}
for (auto index : index_set) {
covered_indexes.insert(index);
std::string masked_token;
if (u1(e) < 0.8) {
masked_token = "[MASK]";
}
else {
if (u1(e) < 0.5) {
masked_token = output_tokens[index];
}
else {
int random_index = u2(e);
masked_token = vocab_words[random_index];
}
}
// masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index]));
masked_lm_output[index] = vocab[output_tokens[index]];
output_tokens[index] = masked_token;
mask_cnt++;
}
}
// for (auto p : masked_lms) {
// masked_lm_output[p.index] = vocab[p.label];
// }
return std::make_tuple(output_tokens, masked_lm_output);
}
PYBIND11_MODULE(mask, m) {
m.def("create_whole_masked_lm_predictions", &create_whole_masked_lm_predictions);
m.def("get_new_segment", &get_new_segment);
}
import multiprocessing
import os
import re
from tqdm import tqdm
from typing import List
import json
import time
import argparse
import functools
def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]:
"""
Args:
document:
flag: Type: str, "all" splits on both Chinese and English punctuation, "zh" splits on Chinese punctuation only, "en" splits on English punctuation only
limit: maximum length of a single sentence, 510 characters by default
Returns: Type:list
"""
sent_list = []
try:
if flag == "zh":
document = re.sub('(?P<quotation_mark>([。?!…](?![”’"\'])))', r'\g<quotation_mark>\n', document) # Chinese single-character sentence terminators
document = re.sub('(?P<quotation_mark>([。?!]|…{1,2})[”’"\'])', r'\g<quotation_mark>\n', document) # terminators followed by closing quotation marks
elif flag == "en":
document = re.sub('(?P<quotation_mark>([.?!](?![”’"\'])))', r'\g<quotation_mark>\n', document) # English single-character sentence terminators
document = re.sub('(?P<quotation_mark>([?!.]["\']))', r'\g<quotation_mark>\n', document) # terminators followed by closing quotation marks
else:
document = re.sub('(?P<quotation_mark>([。?!….?!](?![”’"\'])))', r'\g<quotation_mark>\n', document) # single-character sentence terminators (Chinese and English)
document = re.sub('(?P<quotation_mark>(([。?!.!?]|…{1,2})[”’"\']))', r'\g<quotation_mark>\n',
document) # terminators followed by closing quotation marks
sent_list_ori = document.splitlines()
for sent in sent_list_ori:
sent = sent.strip()
if not sent:
continue
elif len(sent) <= 2:
continue
else:
while len(sent) > limit:
temp = sent[0:limit]
sent_list.append(temp)
sent = sent[limit:]
sent_list.append(sent)
except:
sent_list.clear()
sent_list.append(document)
return sent_list
def get_sent(output_path,
input_path,
fin_list=[], host=-1, seq_len=512) -> None:
workers = 32
if input_path[-1] == '/':
input_path = input_path[:-1]
cur_path = os.path.join(output_path, str(host) + '.txt')
new_split_sentence = functools.partial(split_sentence, limit=seq_len-2)
with open(cur_path, 'w', encoding='utf-8') as f:
for fi, fin_path in enumerate(fin_list):
if not os.path.exists(os.path.join(input_path, fin_path[0])):
continue
if '.json' not in fin_path[0]:
continue
print("Processing ", fin_path[0], " ", fi)
with open(os.path.join(input_path, fin_path[0]), 'r') as fin:
f_data = [l['content'] for l in json.load(fin)]
pool = multiprocessing.Pool(workers)
all_sent = pool.imap_unordered(new_split_sentence, f_data, 32)
pool.close()
print('finished..')
cnt = 0
for d in tqdm(all_sent):
for i in d:
f.write(i.strip() + '\n')
f.write(']]' + '\n')
cnt += 1
# if cnt >= 2:
# exit()
def getFileSize(filepath, shard):
all_data = []
for i in os.listdir(filepath):
all_data.append(os.path.join(filepath, i))
all_size = sum([os.path.getsize(os.path.join(filepath, f)) for f in all_data])
ans = [[f.split('/')[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data]
ans = sorted(ans, key=lambda x: x[1], reverse=True)
per_size = all_size / shard
real_shard = []
temp = []
accu_size = 0
for i in ans:
accu_size += i[1]
temp.append(i)
if accu_size > per_size:
real_shard.append(temp)
accu_size = 0
temp = []
if len(temp) > 0:
real_shard.append(temp)
return real_shard
def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'):
import socket
host = int(socket.gethostname().split(server_name)[-1])
fin_list = real_shard[server_num * base + host - 1]
print(fin_list)
print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}')
return fin_list, host
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--server_num', type=int, default=10, help='number of servers')
parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
parser.add_argument('--shard', type=int, default=100, help='number of shards, e.g., 10, 50, or 100')
parser.add_argument('--input_path', type=str, required=True, help='input path of original corpus')
parser.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence')
args = parser.parse_args()
server_num = args.server_num
seq_len = args.seq_len
shard = args.shard
input_path = args.input_path
output_path = args.output_path
real_shard = getFileSize(input_path, shard)
start = time.time()
for index, shard in enumerate(real_shard):
get_sent(output_path,
input_path,
fin_list=shard,
host=index,
seq_len=seq_len)
print(f'cost {str(time.time() - start)}')
# if you have multiple server, you can use code below or modify code to openmpi
# for i in range(len(real_shard) // server_num + 1):
# fin_list, host = get_start_end(real_shard, i)
# start = time.time()
# get_sent(output_path,
# input_path,
# fin_list=fin_list, host= 10 * i + host - 1)
# print(f'cost {str(time.time() - start)}')
import time
import os
import psutil
import h5py
import socket
import argparse
import numpy as np
import multiprocessing
from tqdm import tqdm
from random import shuffle
from transformers import AutoTokenizer
from get_mask import PreTrainingDataset
def get_raw_instance(document, max_sequence_length=512):
"""
Build preliminary training instances: split a whole passage into multiple parts according to max_sequence_length and return them as a list of processed instances.
:param document: a whole passage
:param max_sequence_length:
:return: a list. each element is a sequence of text
"""
# document = self.documents[index]
max_sequence_length_allowed = max_sequence_length - 2
# document = [seq for seq in document if len(seq)<max_sequence_length_allowed]
sizes = [len(seq) for seq in document]
result_list = []
curr_seq = [] # the sequence currently being built
sz_idx = 0
while sz_idx < len(sizes):
# If the current sequence plus the new sentence stays within the maximum limit, merge them; otherwise the limit is exceeded, so add the current sequence to the result list and start a new one
if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0:
curr_seq += document[sz_idx]
sz_idx += 1
elif sizes[sz_idx] >= max_sequence_length_allowed:
if len(curr_seq) > 0:
result_list.append(curr_seq)
curr_seq = []
result_list.append(document[sz_idx][ : max_sequence_length_allowed])
sz_idx += 1
else:
result_list.append(curr_seq)
curr_seq = []
# Handle the last sequence: if it is too short, discard it
if len(curr_seq) > max_sequence_length_allowed / 2: # /2
result_list.append(curr_seq)
# # compute how many chunks can be obtained in total
# num_instance=int(len(big_list)/max_sequence_length_allowed)+1
# print("num_instance:",num_instance)
# # split into multiple chunks and append them to the list
# result_list=[]
# for j in range(num_instance):
# index=j*max_sequence_length_allowed
# end_index=index+max_sequence_length_allowed if j!=num_instance-1 else -1
# result_list.append(big_list[index:end_index])
return result_list
def split_numpy_chunk(path, tokenizer, pretrain_data, host):
documents = []
instances = []
s = time.time()
with open(path, encoding='utf-8') as fd:
document = []
for i, line in enumerate(tqdm(fd)):
line = line.strip()
# document = line
# if len(document.split("<sep>")) <= 3:
# continue
if len(line) > 0 and line[:2] == "]]": # This is the end of a document
documents.append(document)
document = []
elif len(line) >= 2:
document.append(line)
if len(document) > 0:
documents.append(document)
print('read_file ', time.time() - s)
# documents = [x for x in documents if x]
# print(len(documents))
# print(len(documents[0]))
# print(documents[0][0:10])
from typing import List
import multiprocessing
ans = []
for docs in tqdm(documents):
ans.append(pretrain_data.tokenize(docs))
print(time.time() - s)
del documents
instances = []
for a in tqdm(ans):
raw_ins = get_raw_instance(a)
instances.extend(raw_ins)
del ans
print('len instance', len(instances))
sen_num = len(instances)
seq_len = 512
input_ids = np.zeros([sen_num, seq_len], dtype=np.int32)
input_mask = np.zeros([sen_num, seq_len], dtype=np.int32)
segment_ids = np.zeros([sen_num, seq_len], dtype=np.int32)
masked_lm_output = np.zeros([sen_num, seq_len], dtype=np.int32)
for index, ins in tqdm(enumerate(instances)):
mask_dict = pretrain_data.create_training_instance(ins)
input_ids[index] = mask_dict[0]
input_mask[index] = mask_dict[1]
segment_ids[index] = mask_dict[2]
masked_lm_output[index] = mask_dict[3]
with h5py.File(f'/output/{host}.h5', 'w') as hf:
hf.create_dataset("input_ids", data=input_ids)
hf.create_dataset("input_mask", data=input_ids)
hf.create_dataset("segment_ids", data=segment_ids)
hf.create_dataset("masked_lm_positions", data=masked_lm_output)
del instances
def split_numpy_chunk_pool(input_path,
output_path,
pretrain_data,
worker,
dupe_factor,
seq_len,
file_name):
if os.path.exists(os.path.join(output_path, f'{file_name}.h5')):
print(f'{file_name}.h5 exists')
return
documents = []
instances = []
s = time.time()
with open(input_path, 'r', encoding='utf-8') as fd:
document = []
for i, line in enumerate(tqdm(fd)):
line = line.strip()
if len(line) > 0 and line[:2] == "]]": # This is the end of a document
documents.append(document)
document = []
elif len(line) >= 2:
document.append(line)
if len(document) > 0:
documents.append(document)
print(f'read_file cost {time.time() - s}, length is {len(documents)}')
ans = []
s = time.time()
pool = multiprocessing.Pool(worker)
encoded_doc = pool.imap_unordered(pretrain_data.tokenize, documents, 100)
for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour='cyan'):
ans.append(res)
pool.close()
print((time.time() - s) / 60)
del documents
instances = []
for a in tqdm(ans, colour='MAGENTA'):
raw_ins = get_raw_instance(a, max_sequence_length=seq_len)
instances.extend(raw_ins)
del ans
print('len instance', len(instances))
new_instances = []
for _ in range(dupe_factor):
for ins in instances:
new_instances.append(ins)
shuffle(new_instances)
instances = new_instances
print('after dupe_factor, len instance', len(instances))
sentence_num = len(instances)
input_ids = np.zeros([sentence_num, seq_len], dtype=np.int32)
input_mask = np.zeros([sentence_num, seq_len], dtype=np.int32)
segment_ids = np.zeros([sentence_num, seq_len], dtype=np.int32)
masked_lm_output = np.zeros([sentence_num, seq_len], dtype=np.int32)
s = time.time()
pool = multiprocessing.Pool(worker)
encoded_docs = pool.imap_unordered(pretrain_data.create_training_instance, instances, 32)
for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour='blue'):
input_ids[index] = mask_dict[0]
input_mask[index] = mask_dict[1]
segment_ids[index] = mask_dict[2]
masked_lm_output[index] = mask_dict[3]
pool.close()
print((time.time() - s) / 60)
with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf:
hf.create_dataset("input_ids", data=input_ids)
hf.create_dataset("input_mask", data=input_mask)
hf.create_dataset("segment_ids", data=segment_ids)
hf.create_dataset("masked_lm_positions", data=masked_lm_output)
del instances
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--tokenizer_path', type=str, required=True, help='path of the tokenizer')
parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
parser.add_argument('--max_predictions_per_seq', type=int, default=80, help='maximum number of masked tokens per sequence')
parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence')
parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id')
parser.add_argument('--backend', type=str, default='python', help='masking backend: python or c++')
parser.add_argument('--dupe_factor', type=int, default=1, help='specifies how many times the preprocessor repeats to create the input from the same article/document')
parser.add_argument('--worker', type=int, default=32, help='number of worker processes')
parser.add_argument('--server_num', type=int, default=10, help='number of servers')
args = parser.parse_args()
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
pretrain_data = PreTrainingDataset(tokenizer,
args.seq_len,
args.backend,
max_predictions_per_seq=args.max_predictions_per_seq)
data_len = len(os.listdir(args.input_path))
for i in range(data_len):
input_path = os.path.join(args.input_path, f'{i}.txt')
if os.path.exists(input_path):
start = time.time()
print(f'process {input_path}')
split_numpy_chunk_pool(input_path,
args.output_path,
pretrain_data,
args.worker,
args.dupe_factor,
args.seq_len,
i)
end_ = time.time()
print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )
print(f'has cost {(end_ - start) / 60}')
print('-' * 100)
print('')
# if you have multiple server, you can use code below or modify code to openmpi
# host = int(socket.gethostname().split('GPU')[-1])
# for i in range(data_len // args.server_num + 1):
# h = args.server_num * i + host - 1
# input_path = os.path.join(args.input_path, f'{h}.txt')
# if os.path.exists(input_path):
# start = time.time()
# print(f'I am server {host}, process {input_path}')
# split_numpy_chunk_pool(input_path,
# args.output_path,
# pretrain_data,
# args.worker,
# args.dupe_factor,
# args.seq_len,
# h)
# end_ = time.time()
# print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )
# print(f'has cost {(end_ - start) / 60}')
# print('-' * 100)
# print('')
# Pretraining
1. Pretrain RoBERTa by running the script below. Detailed parameter descriptions can be found in arguments.py. `data_path_prefix` is the absolute path to the output of preprocessing. **You have to modify the *hostfile* according to your cluster.** A sketch at the end of this README shows how these pieces fit together in code.
```bash
bash run_pretrain.sh
```
* `--hostfile`: servers' host names from /etc/hosts
* `--include`: servers which will be used
* `--nproc_per_node`: number of processes (GPUs) per server
* `--data_path_prefix`: absolute location of the training data, e.g., /h5/0.h5
* `--eval_data_path_prefix`: absolute location of the evaluation data
* `--tokenizer_path`: tokenizer path containing the Hugging Face tokenizer.json, e.g., /tokenizer/tokenizer.json
* `--bert_config`: config.json which describes the model
* `--mlm`: backbone model type, bert or deberta_v2
2. To resume training from an earlier checkpoint, run the script below.
```shell
bash run_pretrain_resume.sh
```
* `--resume_train`: whether to resume training
* `--load_pretrain_model`: absolute path of the model checkpoint
* `--load_optimizer_lr`: absolute path of the optimizer checkpoint
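For orientation, the sketch below shows how the pieces referenced above fit together in a single forward/backward step: the config file passed via `--bert_config`, one h5 shard produced by preprocessing, and the masked-LM loss from loss.py. All paths are assumptions, and the real training loop wraps the model and optimizer with the Colossal-AI engine instead of calling them directly.

```python
import h5py
import torch
from transformers import BertConfig, BertForMaskedLM

# hypothetical paths
config = BertConfig.from_json_file("./config/config.json")    # the file passed via --bert_config
model = BertForMaskedLM(config)
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)        # same behaviour as LossForPretraining in loss.py

with h5py.File("/h5/0.h5", "r") as hf:
    input_ids = torch.from_numpy(hf["input_ids"][:2]).long()
    attention_mask = torch.from_numpy(hf["input_mask"][:2]).long()
    token_type_ids = torch.from_numpy(hf["segment_ids"][:2]).long()
    mlm_labels = torch.from_numpy(hf["masked_lm_positions"][:2]).long()

output = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
loss = criterion(output.logits.view(-1, config.vocab_size), mlm_labels.view(-1))
loss.backward()
```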
import colossalai
from numpy import require
__all__ = ['parse_args']
def parse_args():
parser = colossalai.get_default_parser()
parser.add_argument(
'--lr',
type=float,
required=True,
help='initial learning rate')
parser.add_argument(
'--epoch',
type=int,
required=True,
help='number of epoch')
parser.add_argument(
'--data_path_prefix',
type=str,
required=True,
help="location of the train data corpus")
parser.add_argument(
'--eval_data_path_prefix',
type=str,
required=True,
help='location of the evaluation data corpus')
parser.add_argument(
'--tokenizer_path',
type=str,
required=True,
help='location of the tokenizer')
parser.add_argument(
'--max_seq_length',
type=int,
default=512,
help='sequence length')
parser.add_argument(
'--refresh_bucket_size',
type=int,
default=1,
help=
"This param makes sure that a certain task is repeated for this time steps to \
optimise on the back propogation speed with APEX's DistributedDataParallel")
parser.add_argument(
"--max_predictions_per_seq",
"--max_pred",
default=80,
type=int,
help=
"The maximum number of masked tokens in a sequence to be predicted.")
parser.add_argument(
"--gradient_accumulation_steps",
default=1,
type=int,
help="accumulation_steps")
parser.add_argument(
"--train_micro_batch_size_per_gpu",
default=2,
type=int,
required=True,
help="train batch size")
parser.add_argument(
"--eval_micro_batch_size_per_gpu",
default=2,
type=int,
required=True,
help="eval batch size")
parser.add_argument(
"--num_workers",
default=8,
type=int,
help="")
parser.add_argument(
"--async_worker",
action='store_true',
help="")
parser.add_argument(
"--bert_config",
required=True,
type=str,
help="location of config.json")
parser.add_argument(
"--wandb",
action='store_true',
help="use wandb to watch model")
parser.add_argument(
"--wandb_project_name",
default='roberta',
help="wandb project name")
parser.add_argument(
"--log_interval",
default=100,
type=int,
help="report interval")
parser.add_argument(
"--log_path",
type=str,
required=True,
help="log file which records train step")
parser.add_argument(
"--tensorboard_path",
type=str,
required=True,
help="location of tensorboard file")
parser.add_argument(
"--colossal_config",
type=str,
required=True,
help="colossal config, which contains zero config and so on")
parser.add_argument(
"--ckpt_path",
type=str,
required=True,
help="location of saving checkpoint, which contains model and optimizer")
parser.add_argument(
'--seed',
type=int,
default=42,
help="random seed for initialization")
parser.add_argument(
'--vscode_debug',
action='store_true',
help="use vscode to debug")
parser.add_argument(
'--load_pretrain_model',
default='',
type=str,
help="location of model's checkpoin")
parser.add_argument(
'--load_optimizer_lr',
default='',
type=str,
help="location of checkpoint, which contains optimerzier, learning rate, epoch, shard and global_step")
parser.add_argument(
'--resume_train',
action='store_true',
help="whether resume training from a early checkpoint")
parser.add_argument(
'--mlm',
default='bert',
type=str,
help="model type, bert or deberta")
parser.add_argument(
'--checkpoint_activations',
action='store_true',
help="whether to use gradient checkpointing")
args = parser.parse_args()
return args
class BertDatasetProviderInterface:
def get_shard(self, index, shuffle=True):
raise NotImplementedError
def release_shard(self, index):
raise NotImplementedError
def prefetch_shard(self, index):
raise NotImplementedError
def get_batch(self, batch_iter):
raise NotImplementedError
def prefetch_batch(self):
raise NotImplementedError
import os
import math
import torch
from tqdm import tqdm
from utils.global_vars import get_timers, get_tensorboard_writer
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
def evaluate(engine, args, logger, global_step):
evaluate_dataset_provider = NvidiaBertDatasetProvider(args, evaluate=True)
start_shard = 0
engine.eval()
timers = get_timers()
eval_step = 0
eval_loss = 0
cur_loss = 0
world_size = torch.distributed.get_world_size()
with torch.no_grad():
for shard in range(start_shard, len(os.listdir(args.eval_data_path_prefix))):
timers('eval_shard_time').start()
dataset_iterator, total_length = evaluate_dataset_provider.get_shard(shard)
# evaluate_dataset_provider.prefetch_shard(shard + 1)
if torch.distributed.get_rank() == 0:
iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), colour='MAGENTA', smoothing=1)
else:
iterator_data = enumerate(dataset_iterator)
for step, batch_data in iterator_data: #tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1):
# batch_data = pretrain_dataset_provider.get_batch(batch_index)
eval_step += 1
input_ids = batch_data[0].cuda()
attention_mask = batch_data[1].cuda()
token_type_ids = batch_data[2].cuda()
mlm_label = batch_data[3].cuda()
# nsp_label = batch_data[5].cuda()
output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
loss = engine.criterion(output.logits, mlm_label)#prediction_scores
evaluate_dataset_provider.prefetch_batch()
eval_loss += loss.float().item()
cur_loss = eval_loss / eval_step
elapsed_time = timers("eval_shard_time").elapsed()
elapsed_time_per_iteration = elapsed_time / eval_step
ppl = math.exp(cur_loss)
if args.wandb and torch.distributed.get_rank() == 0:
tensorboard_log = get_tensorboard_writer()
tensorboard_log.log_eval({
'loss': cur_loss,
'ppl': ppl,
'mins_batch': elapsed_time_per_iteration
}, global_step)
eval_log_str = f'evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \
f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}'
logger.info(eval_log_str)
logger.info('-' * 100)
logger.info('')
evaluate_dataset_provider.release_shard()
engine.train()
return cur_loss
GPU001
GPU002
GPU003
GPU004
GPU005
GPU006
GPU007
GPU008
GPU009
GPU010
import torch
__all__ = ['LossForPretraining']
class LossForPretraining(torch.nn.Module):
def __init__(self, vocab_size):
super(LossForPretraining, self).__init__()
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
self.vocab_size = vocab_size
def forward(self, prediction_scores, masked_lm_labels, next_sentence_labels=None):
masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1))
# next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1))
total_loss = masked_lm_loss #+ next_sentence_loss
return total_loss