Commit a1c29028 authored by zhangqha

update uni-fold

## A simple BERT example
1. download data: `wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip && unzip wikitext-2-v1.zip`
2. run `python preprocess.py ./wikitext-2/wiki.train.tokens ./train.lmdb`
3. run `python preprocess.py ./wikitext-2/wiki.valid.tokens ./valid.lmdb`
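
Optionally, you can sanity-check the resulting LMDB. The sketch below only assumes the key/value layout used by `preprocess.py` (ASCII line indices mapped to pickled text lines):

```python
import lmdb
import pickle

# open the file-based LMDB written by preprocess.py
env = lmdb.open("./train.lmdb", subdir=False, readonly=True, lock=False)
with env.begin() as txn:
    print("num records:", txn.stat()["entries"])
    # keys are stringified line indices, values are pickled text lines
    print("first line:", pickle.loads(txn.get("0".encode("ascii"))))
```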
import os
import sys
import pickle
import lmdb
def write_to_lmdb(filename, outfilename):
try:
os.remove(outfilename)
    except OSError:
pass
env_new = lmdb.open(
outfilename,
subdir=False,
readonly=False,
lock=False,
readahead=False,
meminit=False,
max_readers=1,
map_size=int(100e9),
)
txn_write = env_new.begin(write = True)
with open(filename, 'r') as input:
i = 0
for line in input.readlines():
line = line.strip()
if line:
txn_write.put(f'{i}'.encode("ascii"), pickle.dumps(line))
i += 1
        print('processed {} lines'.format(i))
txn_write.commit()
env_new.close()
if __name__ == '__main__':
    write_to_lmdb(sys.argv[1], sys.argv[2])
# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from unicore import utils
from unicore.models import BaseUnicoreModel, register_model, register_model_architecture
from unicore.modules import LayerNorm, TransformerEncoder, init_bert_params
logger = logging.getLogger(__name__)
@register_model("bert")
class BertModel(BaseUnicoreModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument(
"--encoder-layers", type=int, metavar="L", help="num encoder layers"
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="H",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="F",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="A",
help="num encoder attention heads",
)
parser.add_argument(
"--activation-fn",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--pooler-activation-fn",
choices=utils.get_available_activation_fns(),
help="activation function to use for pooler layer",
)
parser.add_argument(
"--emb-dropout", type=float, metavar="D", help="dropout probability for embeddings"
)
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN",
)
parser.add_argument(
"--pooler-dropout",
type=float,
metavar="D",
help="dropout probability in the masked_lm pooler layers",
)
parser.add_argument(
"--max-seq-len", type=int, help="number of positional embeddings to learn"
)
parser.add_argument(
"--post-ln", type=bool, help="use post layernorm or pre layernorm"
)
def __init__(self, args, dictionary):
super().__init__()
base_architecture(args)
self.args = args
self.padding_idx = dictionary.pad()
self.embed_tokens = nn.Embedding(len(dictionary), args.encoder_embed_dim, self.padding_idx)
self.embed_positions = nn.Embedding(args.max_seq_len, args.encoder_embed_dim)
self.sentence_encoder = TransformerEncoder(
encoder_layers=args.encoder_layers,
embed_dim=args.encoder_embed_dim,
ffn_embed_dim=args.encoder_ffn_embed_dim,
attention_heads=args.encoder_attention_heads,
emb_dropout=args.emb_dropout,
dropout=args.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
max_seq_len=args.max_seq_len,
activation_fn=args.activation_fn,
rel_pos=True,
rel_pos_bins=32,
max_rel_pos=128,
post_ln=args.post_ln,
)
self.lm_head = BertLMHead(embed_dim=args.encoder_embed_dim,
output_dim=len(dictionary),
activation_fn=args.activation_fn,
weight=self.embed_tokens.weight,
)
self.classification_heads = nn.ModuleDict()
self.apply(init_bert_params)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
return cls(args, task.dictionary)
def forward(
self,
src_tokens,
masked_tokens,
features_only=False,
classification_head_name=None,
**kwargs
):
if classification_head_name is not None:
features_only = True
padding_mask = src_tokens.eq(self.padding_idx)
if not padding_mask.any():
padding_mask = None
x = self.embed_tokens(src_tokens)
x += self.embed_positions.weight[:src_tokens.size(1), :]
x = self.sentence_encoder(x, padding_mask=padding_mask)
if not features_only:
x = self.lm_head(x, masked_tokens)
if classification_head_name is not None:
x = self.classification_heads[classification_head_name](x)
return x
def register_classification_head(
self, name, num_classes=None, inner_dim=None, **kwargs
):
"""Register a classification head."""
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
prev_inner_dim = self.classification_heads[name].dense.out_features
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
"and inner_dim {} (prev: {})".format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
self.classification_heads[name] = BertClassificationHead(
input_dim=self.args.encoder_embed_dim,
inner_dim=inner_dim or self.args.encoder_embed_dim,
num_classes=num_classes,
activation_fn=self.args.pooler_activation_fn,
pooler_dropout=self.args.pooler_dropout,
)
class BertLMHead(nn.Module):
"""Head for masked language modeling."""
def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
super().__init__()
self.dense = nn.Linear(embed_dim, embed_dim)
self.activation_fn = utils.get_activation_fn(activation_fn)
self.layer_norm = LayerNorm(embed_dim)
if weight is None:
weight = nn.Linear(embed_dim, output_dim, bias=False).weight
self.weight = weight
self.bias = nn.Parameter(torch.zeros(output_dim))
def forward(self, features, masked_tokens=None, **kwargs):
# Only project the masked tokens while training,
# saves both memory and computation
if masked_tokens is not None:
features = features[masked_tokens, :]
x = self.dense(features)
x = self.activation_fn(x)
x = self.layer_norm(x)
# project back to size of vocabulary with bias
x = F.linear(x, self.weight) + self.bias
return x
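# Shape sketch for the masked-token path above (illustration only, not part of the model):
# with `features` of shape [batch, seq_len, embed_dim] and `masked_tokens` a boolean mask of
# shape [batch, seq_len], `features[masked_tokens, :]` keeps only the masked positions as a
# [num_masked, embed_dim] tensor, so the dense / layer_norm / vocab projection run on far
# fewer rows than batch * seq_len. For example:
#   features = torch.randn(2, 8, 768)
#   masked_tokens = torch.zeros(2, 8, dtype=torch.bool)
#   masked_tokens[0, 3] = masked_tokens[1, 5] = True
#   features[masked_tokens, :].shape  # -> torch.Size([2, 768])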
class BertClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(
self,
input_dim,
inner_dim,
num_classes,
activation_fn,
pooler_dropout,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
self.activation_fn = utils.get_activation_fn(activation_fn)
self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, features, **kwargs):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
x = self.activation_fn(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
@register_model_architecture("bert", "bert")
def base_architecture(args):
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
args.dropout = getattr(args, "dropout", 0.1)
args.emb_dropout = getattr(args, "emb_dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
args.max_seq_len = getattr(args, "max_seq_len", 512)
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
args.post_ln = getattr(args, "post_ln", True)
@register_model_architecture("bert", "bert_base")
def bert_base_architecture(args):
base_architecture(args)
@register_model_architecture("bert", "bert_large")
def bert_large_architecture(args):
args.encoder_layers = getattr(args, "encoder_layers", 24)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
base_architecture(args)
@register_model_architecture("bert", "xlm")
def xlm_architecture(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1280)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1280 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
base_architecture(args)
# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import contextlib
from typing import Optional
import numpy as np
from unicore.data import (
Dictionary,
MaskTokensDataset,
NestedDictionaryDataset,
NumelDataset,
NumSamplesDataset,
LMDBDataset,
PrependTokenDataset,
RightPadDataset,
SortDataset,
BertTokenizeDataset,
data_utils,
)
from unicore.tasks import UnicoreTask, register_task
logger = logging.getLogger(__name__)
@register_task("bert")
class BertTask(UnicoreTask):
"""Task for training masked language models (e.g., BERT)."""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument(
"data",
help="colon separated path to data directories list, \
will be iterated upon during epochs in round-robin manner",
)
parser.add_argument(
"--mask-prob",
default=0.15,
type=float,
help="probability of replacing a token with mask",
)
parser.add_argument(
"--leave-unmasked-prob",
default=0.1,
type=float,
help="probability that a masked token is unmasked",
)
parser.add_argument(
"--random-token-prob",
default=0.1,
type=float,
help="probability of replacing a token with a random token",
)
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
self.seed = args.seed
# add mask token
self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True)
@classmethod
def setup_task(cls, args, **kwargs):
dictionary = Dictionary.load(os.path.join(args.data, "dict.txt"))
logger.info("dictionary: {} types".format(len(dictionary)))
return cls(args, dictionary)
def load_dataset(self, split, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
split_path = os.path.join(self.args.data, split + '.lmdb')
dict_path = os.path.join(self.args.data, "dict.txt")
dataset = LMDBDataset(split_path)
dataset = BertTokenizeDataset(dataset, dict_path, max_seq_len=self.args.max_seq_len)
src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
dataset,
self.dictionary,
pad_idx=self.dictionary.pad(),
mask_idx=self.mask_idx,
seed=self.args.seed,
mask_prob=self.args.mask_prob,
leave_unmasked_prob=self.args.leave_unmasked_prob,
random_token_prob=self.args.random_token_prob,
)
with data_utils.numpy_seed(self.args.seed):
shuffle = np.random.permutation(len(src_dataset))
self.datasets[split] = SortDataset(
NestedDictionaryDataset(
{
"net_input": {
"src_tokens": RightPadDataset(
src_dataset,
pad_idx=self.dictionary.pad(),
)
},
"target": RightPadDataset(
tgt_dataset,
pad_idx=self.dictionary.pad(),
),
},
),
sort_order=[
shuffle
],
)
def build_model(self, args):
from unicore import models
model = models.build_model(args, self)
return model
[ -z "${MASTER_PORT}" ] && MASTER_PORT=10086
[ -z "${n_gpu}" ] && n_gpu=$(nvidia-smi -L | wc -l)
export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1
python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) ./example_data --user-dir . --valid-subset valid \
--num-workers 0 --ddp-backend=c10d \
--task bert --loss masked_lm --arch bert_base \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-6 --clip-norm 1.0 \
--lr-scheduler polynomial_decay --lr 1e-4 --warmup-updates 100 --total-num-update 10000 --batch-size 4 \
--update-freq 1 --seed 1 \
--fp16 --fp16-init-scale 4 --fp16-scale-window 256 --tensorboard-logdir ./tsb/ \
--max-update 10000 --log-interval 100 --log-format simple \
--save-interval-updates 5000 --validate-interval-updates 5000 --keep-interval-updates 30 --no-epoch-checkpoints \
--save-dir ./save
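# Usage sketch (the script name below is an assumption; MASTER_PORT and n_gpu default as above):
#   n_gpu=1 MASTER_PORT=12345 bash train_bert.sh
# With the flags above, checkpoints are written to ./save and TensorBoard logs to ./tsb/.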
iopath
lmdb
ml_collections
numpy
scipy
tensorboardX
tqdm
tokenizers
#!/usr/bin/env python3 -u
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.utils import cpp_extension
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
import os
import subprocess
import sys
from setuptools import find_packages, setup
if sys.version_info < (3, 7):
sys.exit("Sorry, Python >= 3.7 is required for unicore.")
def write_version_py():
with open(os.path.join("unicore", "version.txt")) as f:
version = f.read().strip()
# write version info to unicore/version.py
with open(os.path.join("unicore", "version.py"), "w") as f:
f.write('__version__ = "{}"\n'.format(version))
return version
version = write_version_py()
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/hipcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
if not torch.cuda.is_available():
print('\nWarning: Torch did not find available GPUs on this system.\n',
'If your intention is to cross-compile, this is not an error.\n'
'By default, it will cross-compile for Volta (compute capability 7.0), Turing (compute capability 7.5),\n'
'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
'If you wish to cross-compile for a single specific architecture,\n'
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) == 11:
os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5;8.0"
else:
os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5"
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 4):
raise RuntimeError("Requires Pytorch 1.4 or newer.\n" +
"The latest stable release can be obtained from https://pytorch.org/")
cmdclass = {}
ext_modules = []
extras = {}
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/hipcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
torch_binary_major = torch.version.cuda.split(".")[0]
torch_binary_minor = torch.version.cuda.split(".")[1]
print("\nCompiling cuda extensions with")
print(raw_output + "from " + cuda_dir + "/bin\n")
if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
raise RuntimeError("Cuda extensions are being compiled with a version of Cuda that does " +
"not match the version used to compile Pytorch binaries. " +
"Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda))
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("Nvcc was not found. Are you sure your environment has hipcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide hipcc.")
#check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)
generator_flag = []
#generator_flag += [('HIP_DIFF',None)]
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')):
generator_flag = ['-DOLD_GENERATOR']
ext_modules.append(
CUDAExtension(name='unicore_fused_rounding',
sources=['csrc/rounding/interface_hip.cpp',
'csrc/rounding/fp32_to_bf16.hip'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + generator_flag,
'hipcc':['-O3', '--use_fast_math',
'-gencode', 'arch=gfx906',
'-gencode', 'arch=gfx906',
# '-U__CUDA_NO_HALF_OPERATORS__',
# '-U__CUDA_NO_BFLOAT16_OPERATORS__',
# '-U__CUDA_NO_HALF_CONVERSIONS__',
# '-U__CUDA_NO_BFLOAT16_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda'] + generator_flag}))
ext_modules.append(
CUDAExtension(name='unicore_fused_multi_tensor',
sources=['csrc/multi_tensor/interface.cpp',
# 'csrc/multi_tensor/multi_tensor_l2norm_kernel.cu'],
'csrc/multi_tensor/multi_tensor_l2norm_kernel.hip'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'],
'hipcc':['-O3', '--use_fast_math',
'-gencode', 'arch=gfx906',
'-gencode', 'arch=gfx906',
# '-U__CUDA_NO_HALF_OPERATORS__',
# '-U__CUDA_NO_BFLOAT16_OPERATORS__',
# '-U__CUDA_NO_HALF_CONVERSIONS__',
# '-U__CUDA_NO_BFLOAT16_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda']
}))
ext_modules.append(
CUDAExtension(name='unicore_fused_adam',
sources=['csrc/adam/interface.cpp',
'csrc/adam/adam_kernel.hip'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'],
'hipcc':['-O3', '--use_fast_math']}))
ext_modules.append(
CUDAExtension(name='unicore_fused_softmax_dropout',
sources=['csrc/softmax_dropout/interface_hip.cpp',
'csrc/softmax_dropout/softmax_dropout_kernel.hip'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + generator_flag,
'hipcc':['-O3', '--use_fast_math',
'-gencode', 'arch=gfx906',
'-gencode', 'arch=gfx906',
#'-U__CUDA_NO_HALF_OPERATORS__',
#'-U__CUDA_NO_BFLOAT16_OPERATORS__',
#'-U__CUDA_NO_HALF_CONVERSIONS__',
#'-U__CUDA_NO_BFLOAT16_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda'] + generator_flag}))
ext_modules.append(
CUDAExtension(name='unicore_fused_layernorm',
sources=['csrc/layernorm/interface.cpp',
'csrc/layernorm/layernorm.hip'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + generator_flag,
'hipcc':['-O3', '--use_fast_math',
'-gencode', 'arch=gfx906',
'-gencode', 'arch=gfx906',
#'-U__CUDA_NO_HALF_OPERATORS__',
#'-U__CUDA_NO_BFLOAT16_OPERATORS__',
#'-U__CUDA_NO_HALF_CONVERSIONS__',
#'-U__CUDA_NO_BFLOAT16_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda'] + generator_flag}))
ext_modules.append(
CUDAExtension(name='unicore_fused_layernorm_backward_gamma_beta',
sources=['csrc/layernorm/interface_gamma_beta.cpp',
'csrc/layernorm/layernorm_backward.hip'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + generator_flag,
'hipcc':['-O3', '--use_fast_math', '-maxrregcount=50',
'-gencode', 'arch=gfx906',
'-gencode', 'arch=gfx906',
#'-U__CUDA_NO_HALF_OPERATORS__',
#'-U__CUDA_NO_BFLOAT16_OPERATORS__',
#'-U__CUDA_NO_HALF_CONVERSIONS__',
#'-U__CUDA_NO_BFLOAT16_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda'] + generator_flag}))
setup(
name="unicore",
version=version,
description="DP Technology's Core AI Framework",
url="https://github.com/dptech-corp/unicore",
classifiers=[
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
setup_requires=[
"setuptools>=18.0",
],
install_requires=[
'numpy; python_version>="3.7"',
"lmdb",
#"torch>=1.10.0",
"tqdm",
"ml_collections",
"scipy",
"tensorboardX",
"tokenizers",
],
packages=find_packages(
exclude=[
'build',
'csrc',
"examples",
"examples.*",
"scripts",
"scripts.*",
"tests",
"tests.*",
]
),
ext_modules=ext_modules,
cmdclass=cmdclass,
extras_require=extras,
entry_points={
"console_scripts": [
"unicore-train = unicore_cli.train:cli_main",
],
},
zip_safe=False,
)
import torch
import torch.nn.functional as F
from unicore.modules import softmax_dropout
def gen_attn_mask(mask, neg_inf):
assert neg_inf < -1e4
attn_mask = torch.zeros_like(mask)
attn_mask[mask == 0] = neg_inf
return attn_mask
def normal_softmax(a, mask, bias):
return F.softmax(a + mask + bias, dim=-1)
def fused_softmax(a, mask, bias):
return softmax_dropout(a, 0, True, mask=mask, bias=bias)
def wrap_forward_backward(func, a1, mask, bias1):
a = a1.clone()
bias = bias1.clone()
a.requires_grad = True
bias.requires_grad = True
output = func(a, mask, bias)
o = output.float().sum()
o.backward()
return output, a.grad, bias.grad
def check_diff(a, b, name, eps=1e-3):
assert (a - b).abs().max() < eps, "name {}, diff {}".format(
name, (a - b).abs().max()
)
def test_softmax():
n_batch = 4
n_heads = 8
n_query = 128
test_dims = [64, 128, 256, 512, 1024, 1536, 2048]
test_dtype = [torch.float32, torch.float16, torch.bfloat16]
test_device = torch.device("cuda")
for last_dim in test_dims:
for dtype in test_dtype:
x = torch.rand(
n_batch,
n_heads,
n_query,
last_dim,
dtype=dtype,
device=test_device,
)
mask = gen_attn_mask(
(
torch.rand(
n_batch,
1,
1,
last_dim,
dtype=dtype,
device=test_device,
)
> 0.2
).type(x.dtype),
-3e4,
)
bias = torch.rand(
n_batch, n_heads, n_query, last_dim, dtype=dtype, device=test_device
)
out_a1, out_b1, out_c1 = wrap_forward_backward(
normal_softmax, x, mask, bias
)
out_a2, out_b2, out_c2 = wrap_forward_backward(fused_softmax, x, mask, bias)
check_diff(out_a1, out_a2, "output")
check_diff(out_b1, out_b2, "grad_input")
check_diff(out_c1, out_c2, "grad_bias")
def test_tri_softmax1():
n_batch = 2
n_groups = 32
n_heads = 8
n_query = 128
test_dims = [64, 128, 256, 512, 1024, 1536, 2048]
test_dtype = [torch.float32, torch.float16, torch.bfloat16]
test_device = torch.device("cuda")
for last_dim in test_dims:
for dtype in test_dtype:
x = torch.rand(
n_batch,
n_groups,
n_heads,
n_query,
last_dim,
dtype=dtype,
device=test_device,
)
mask = gen_attn_mask(
(
torch.rand(
n_batch,
n_groups,
1,
1,
last_dim,
dtype=dtype,
device=test_device,
)
> 0.2
).type(x.dtype),
-3e4,
)
bias = torch.rand(
1, 1, n_heads, n_query, last_dim, dtype=dtype, device=test_device
)
out_a1, out_b1, out_c1 = wrap_forward_backward(
normal_softmax, x, mask, bias
)
out_a2, out_b2, out_c2 = wrap_forward_backward(fused_softmax, x, mask, bias)
check_diff(out_a1, out_a2, "output")
check_diff(out_b1, out_b2, "grad_input")
check_diff(out_c1, out_c2, "grad_bias")
def test_tri_softmax2():
n_batch = 2
n_groups = 32
n_heads = 8
n_query = 128
test_dims = [64, 128, 256, 512, 1024, 1536, 2048]
test_dtype = [torch.float32, torch.float16, torch.bfloat16]
test_device = torch.device("cuda")
for last_dim in test_dims:
for dtype in test_dtype:
x = torch.rand(
n_batch,
n_groups,
n_heads,
n_query,
last_dim,
dtype=dtype,
device=test_device,
)
mask = gen_attn_mask(
(
torch.rand(
n_batch,
n_groups,
n_heads,
1,
last_dim,
dtype=dtype,
device=test_device,
)
> 0.2
).type(x.dtype),
-3e4,
)
bias = torch.rand(
1, n_groups, n_heads, n_query, last_dim, dtype=dtype, device=test_device
)
out_a1, out_b1, out_c1 = wrap_forward_backward(
normal_softmax, x, mask, bias
)
out_a2, out_b2, out_c2 = wrap_forward_backward(fused_softmax, x, mask, bias)
check_diff(out_a1, out_a2, "output")
check_diff(out_b1, out_b2, "grad_input")
check_diff(out_c1, out_c2, "grad_bias")
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import os
import sys
try:
from .version import __version__ # noqa
except ImportError:
version_txt = os.path.join(os.path.dirname(__file__), "version.txt")
with open(version_txt) as f:
__version__ = f.read().strip()
__all__ = ["pdb"]
# backwards compatibility to support `from unicore.X import Y`
from unicore.distributed import utils as distributed_utils
from unicore.logging import meters, metrics, progress_bar # noqa
sys.modules["unicore.distributed_utils"] = distributed_utils
sys.modules["unicore.meters"] = meters
sys.modules["unicore.metrics"] = metrics
sys.modules["unicore.progress_bar"] = progress_bar
import unicore.losses # noqa
import unicore.distributed # noqa
import unicore.models # noqa
import unicore.modules # noqa
import unicore.optim # noqa
import unicore.optim.lr_scheduler # noqa
import unicore.tasks # noqa
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ast
import collections
import logging
import os
import re
import shutil
import traceback
from typing import Any, Dict, Optional
import torch
logger = logging.getLogger(__name__)
# async ckp copy
def ckp_copy_fun(src, checkpoints, end_of_epoch, args):
has_copy = False
can_delete = args.tmp_save_dir != args.save_dir
for cp in checkpoints:
try:
if src != cp:
logger.info("copy {} to {}".format(src, cp))
has_copy = True
shutil.copyfile(src, cp)
        except Exception:
            logger.info("copy failed, please copy it manually")
try:
if can_delete and has_copy and os.path.lexists(src):
logger.info("removing temp file {} ...".format(src))
os.remove(src)
def remove_ckps(root_path):
if not end_of_epoch and args.keep_interval_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths(
root_path, pattern=r"checkpoint_\d+_(\d+)\.pt"
)
for old_chk in checkpoints[args.keep_interval_updates :]:
if os.path.lexists(old_chk):
os.remove(old_chk)
logger.info("removed {}".format(old_chk))
if args.keep_last_epochs > 0:
# remove old epoch checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths(root_path, pattern=r"checkpoint(\d+)\.pt")
for old_chk in checkpoints[args.keep_last_epochs :]:
if os.path.lexists(old_chk):
os.remove(old_chk)
logger.info("removed {}".format(old_chk))
if args.keep_best_checkpoints > 0:
# only keep the best N checkpoints according to validation metric
checkpoints = checkpoint_paths(
root_path,
pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
args.best_checkpoint_metric
),
)
if not args.maximize_best_checkpoint_metric:
checkpoints = checkpoints[::-1]
for old_chk in checkpoints[args.keep_best_checkpoints :]:
if os.path.lexists(old_chk):
os.remove(old_chk)
logger.info("removed {}".format(old_chk))
remove_ckps(args.save_dir)
    except Exception:
        logger.info("failed to remove old checkpoints")
logger.info("finished async ckp saving.")
def save_checkpoint(args, trainer, epoch_itr, val_loss, ckp_copy_thread, do_save=True):
from unicore import meters
# only one worker should attempt to create the required dir
if trainer.data_parallel_rank == 0:
os.makedirs(args.save_dir, exist_ok=True)
prev_best = getattr(save_checkpoint, "best", val_loss)
if val_loss is not None:
best_function = max if args.maximize_best_checkpoint_metric else min
save_checkpoint.best = best_function(val_loss, prev_best)
if args.no_save or not do_save:
return
if not trainer.should_save_checkpoint_on_current_rank:
return
write_timer = meters.StopwatchMeter()
write_timer.start()
epoch = epoch_itr.epoch
end_of_epoch = epoch_itr.end_of_epoch()
updates = trainer.get_num_updates()
logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")
def is_better(a, b):
return a >= b if args.maximize_best_checkpoint_metric else a <= b
suffix = trainer.checkpoint_suffix
checkpoint_conds = collections.OrderedDict()
checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
end_of_epoch and not args.no_epoch_checkpoints and epoch % args.save_interval == 0
)
checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
not end_of_epoch
and args.save_interval_updates > 0
and updates % args.save_interval_updates == 0
)
checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
not hasattr(save_checkpoint, "best")
or is_better(val_loss, save_checkpoint.best)
)
if val_loss is not None and args.keep_best_checkpoints > 0:
checkpoint_conds[
"checkpoint.best_{}_{:.2f}.pt".format(args.best_checkpoint_metric, val_loss)
] = not hasattr(save_checkpoint, "best") or is_better(
val_loss, save_checkpoint.best
)
checkpoint_conds[
"checkpoint_last{}.pt".format(suffix)
] = not args.no_last_checkpoints
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
if hasattr(save_checkpoint, "best"):
extra_state.update({"best": save_checkpoint.best})
checkpoints = [
os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
]
tmp_checkpoints = [
os.path.join(args.tmp_save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
]
if len(checkpoints) > 0:
trainer.save_checkpoint(tmp_checkpoints[0], extra_state)
if ckp_copy_thread is not None:
ckp_copy_thread.apply_async(ckp_copy_fun, (tmp_checkpoints[0], checkpoints, end_of_epoch, args))
write_timer.stop()
logger.info(
"Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
tmp_checkpoints[0], epoch, updates, val_loss, write_timer.sum
)
)
def load_checkpoint(args, trainer, **passthrough_args):
"""
Load a checkpoint and restore the training iterator.
*passthrough_args* will be passed through to
``trainer.get_train_iterator``.
"""
reset_optimizer = args.reset_optimizer
reset_lr_scheduler = args.reset_lr_scheduler
optimizer_overrides = ast.literal_eval(args.optimizer_overrides)
reset_meters = args.reset_meters
reset_dataloader = args.reset_dataloader
if args.finetune_from_model is not None and (
reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
):
raise ValueError(
"--finetune-from-model can not be set together with either --reset-optimizer"
" or reset_lr_scheduler or reset_meters or reset_dataloader"
)
suffix = trainer.checkpoint_suffix
if (
args.restore_file == "checkpoint_last.pt"
): # default value of restore_file is 'checkpoint_last.pt'
checkpoint_path = os.path.join(
args.save_dir, "checkpoint_last{}.pt".format(suffix)
)
first_launch = not os.path.exists(checkpoint_path)
if args.finetune_from_model is not None and first_launch:
            # if there is no last checkpoint to restore, start finetuning from the pretrained model;
            # otherwise use the usual logic to load the checkpoint, e.g. restart from the last checkpoint
if os.path.exists(args.finetune_from_model):
checkpoint_path = args.finetune_from_model
reset_optimizer = True
reset_lr_scheduler = True
reset_meters = True
reset_dataloader = True
logger.info(
f"loading pretrained model from {checkpoint_path}: "
"optimizer, lr scheduler, meters, dataloader will be reset"
)
else:
raise ValueError(
f"--funetune-from-model {args.finetune_from_model} does not exist"
)
elif suffix is not None:
checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
else:
checkpoint_path = args.restore_file
if args.restore_file != "checkpoint_last.pt" and args.finetune_from_model:
raise ValueError(
"--finetune-from-model and --restore-file (non-default value) "
"can not be specified together: " + str(args)
)
extra_state = trainer.load_checkpoint(
checkpoint_path,
reset_optimizer,
reset_lr_scheduler,
optimizer_overrides,
reset_meters=reset_meters,
)
if (
extra_state is not None
and "best" in extra_state
and not reset_optimizer
and not reset_meters
):
save_checkpoint.best = extra_state["best"]
if extra_state is not None and not reset_dataloader:
# restore iterator from checkpoint
itr_state = extra_state["train_iterator"]
epoch_itr = trainer.get_train_iterator(
epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
)
epoch_itr.load_state_dict(itr_state)
else:
epoch_itr = trainer.get_train_iterator(
epoch=1, load_dataset=True, **passthrough_args
)
trainer.init_total_train_steps(epoch_itr)
trainer.lr_step(epoch_itr.epoch)
return extra_state, epoch_itr
def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=True):
"""Loads a checkpoint to CPU (with upgrading for backward compatibility).
There's currently no support for > 1 but < all processes loading the
checkpoint on each node.
"""
local_path = path
with open(local_path, "rb") as f:
state = torch.load(f, map_location=torch.device("cpu"))
if "args" in state and state["args"] is not None and arg_overrides is not None:
args = state["args"]
for arg_name, arg_val in arg_overrides.items():
setattr(args, arg_name, arg_val)
return state
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
"""Retrieves all checkpoints found in `path` directory.
Checkpoints are identified by matching filename to the specified pattern. If
the pattern contains groups, the result will be sorted by the first group in
descending order.
"""
pt_regexp = re.compile(pattern)
files = os.listdir(path)
entries = []
for i, f in enumerate(files):
m = pt_regexp.fullmatch(f)
if m is not None:
idx = float(m.group(1)) if len(m.groups()) > 0 else i
entries.append((idx, m.group(0)))
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
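# Example (illustrative only): with a save dir containing checkpoint1.pt, checkpoint2.pt and
# checkpoint3.pt, the default pattern returns them newest-first, which is the ordering the
# keep_last_epochs / keep_interval_updates logic above relies on when deleting old files:
#   checkpoint_paths("./save")
#   # -> ['./save/checkpoint3.pt', './save/checkpoint2.pt', './save/checkpoint1.pt']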
def torch_persistent_save(obj, filename):
# do atomic save
with open(filename + ".tmp", "wb") as f:
_torch_persistent_save(obj, f)
os.rename(filename + ".tmp", filename)
def _torch_persistent_save(obj, f):
if isinstance(f, str):
with open(f, "wb") as h:
torch_persistent_save(obj, h)
return
for i in range(3):
try:
return torch.save(obj, f)
except Exception:
if i == 2:
logger.error(traceback.format_exc())
def verify_checkpoint_directory(save_dir: str) -> None:
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
temp_file_path = os.path.join(save_dir, "dummy")
try:
with open(temp_file_path, "w"):
pass
except OSError as e:
logger.warning(
"Unable to access checkpoint save directory: {}".format(save_dir)
)
raise e
else:
os.remove(temp_file_path)
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
from .unicore_dataset import UnicoreDataset
from .base_wrapper_dataset import BaseWrapperDataset
from .append_token_dataset import AppendTokenDataset
from .dictionary import Dictionary
from .lru_cache_dataset import LRUCacheDataset
from .mask_tokens_dataset import MaskTokensDataset
from .bert_tokenize_dataset import BertTokenizeDataset
from .tokenize_dataset import TokenizeDataset
from .nested_dictionary_dataset import NestedDictionaryDataset
from .numel_dataset import NumelDataset
from .num_samples_dataset import NumSamplesDataset
from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset, RightPadDataset2D
from .prepend_token_dataset import PrependTokenDataset
from .raw_dataset import RawLabelDataset, RawArrayDataset, RawNumpyDataset
from .lmdb_dataset import LMDBDataset
from .sort_dataset import SortDataset, EpochShuffleDataset
from .from_numpy_dataset import FromNumpyDataset
from .iterators import (
CountingIterator,
EpochBatchIterator,
GroupedIterator,
ShardedIterator,
)
__all__ = []
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from functools import lru_cache
from . import BaseWrapperDataset
class AppendTokenDataset(BaseWrapperDataset):
def __init__(self, dataset, token=None):
super().__init__(dataset)
self.token = token
@lru_cache(maxsize=16)
def __getitem__(self, idx):
item = self.dataset[idx]
if self.token is not None:
item = torch.cat([item, torch.full_like(item[0], self.token).unsqueeze(0)], dim=0)
return item
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch.utils.data.dataloader import default_collate
from . import UnicoreDataset
class BaseWrapperDataset(UnicoreDataset):
def __init__(self, dataset):
super().__init__()
self.dataset = dataset
def __getitem__(self, index):
return self.dataset[index]
def __len__(self):
return len(self.dataset)
def collater(self, samples):
if hasattr(self.dataset, "collater"):
return self.dataset.collater(samples)
else:
return default_collate(samples)
def ordered_indices(self):
return self.dataset.ordered_indices()
@property
def supports_prefetch(self):
return getattr(self.dataset, "supports_prefetch", False)
def attr(self, attr: str, index: int):
return self.dataset.attr(attr, index)
def prefetch(self, indices):
self.dataset.prefetch(indices)
def batch_by_size(
self,
indices,
batch_size=None,
required_batch_size_multiple=1,
):
return self.dataset.batch_by_size(
indices,
batch_size=batch_size,
required_batch_size_multiple=required_batch_size_multiple,
)
@property
def can_reuse_epoch_itr_across_epochs(self):
return self.dataset.can_reuse_epoch_itr_across_epochs
def set_epoch(self, epoch):
super().set_epoch(epoch)
if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)
# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import lru_cache
import numpy as np
import torch
from tokenizers import BertWordPieceTokenizer
from . import BaseWrapperDataset, LRUCacheDataset
class BertTokenizeDataset(BaseWrapperDataset):
def __init__(
self,
dataset: torch.utils.data.Dataset,
dict_path: str,
max_seq_len: int=512,
):
        super().__init__(dataset)
self.tokenizer = BertWordPieceTokenizer(dict_path, lowercase=True)
self.max_seq_len = max_seq_len
@property
def can_reuse_epoch_itr_across_epochs(self):
        return True  # tokenization is deterministic, so epoch iterators can be reused
def __getitem__(self, index: int):
raw_str = self.dataset[index]
raw_str = raw_str.replace('<unk>', '[UNK]')
output = self.tokenizer.encode(raw_str)
ret = torch.Tensor(output.ids).long()
if ret.size(0) > self.max_seq_len:
ret = ret[:self.max_seq_len]
return ret
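# Usage sketch (illustrative only, assumes a standard BERT WordPiece vocab file at dict_path):
#   ds = BertTokenizeDataset(raw_text_dataset, "dict.txt", max_seq_len=512)
#   ds[0]  # LongTensor of WordPiece ids, truncated to max_seq_len,
#          # with '<unk>' in the raw text mapped to '[UNK]' first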
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import logging
import numpy as np
import torch
logger = logging.getLogger(__name__)
def collate_tokens(
values,
pad_idx,
left_pad=False,
pad_to_length=None,
pad_to_multiple=1,
):
"""Convert a list of 1d tensors into a padded 2d tensor."""
size = max(v.size(0) for v in values)
size = size if pad_to_length is None else max(size, pad_to_length)
if pad_to_multiple != 1 and size % pad_to_multiple != 0:
size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
res = values[0].new(len(values), size).fill_(pad_idx)
def copy_tensor(src, dst):
assert dst.numel() == src.numel()
dst.copy_(src)
for i, v in enumerate(values):
copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
return res
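# Example (illustrative only): right-padding three sequences of lengths 2, 3 and 1 with pad_idx=0:
#   a, b, c = torch.tensor([4, 5]), torch.tensor([6, 7, 8]), torch.tensor([9])
#   collate_tokens([a, b, c], pad_idx=0)
#   # -> tensor([[4, 5, 0],
#   #            [6, 7, 8],
#   #            [9, 0, 0]])
# With pad_to_multiple=4 the padded length would be rounded up from 3 to 4.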
def collate_tokens_2d(
values,
pad_idx,
left_pad=False,
pad_to_length=None,
pad_to_multiple=1,
):
"""Convert a list of 1d tensors into a padded 2d tensor."""
size = max(v.size(0) for v in values)
size = size if pad_to_length is None else max(size, pad_to_length)
if pad_to_multiple != 1 and size % pad_to_multiple != 0:
size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
res = values[0].new(len(values), size, size).fill_(pad_idx)
def copy_tensor(src, dst):
assert dst.numel() == src.numel()
dst.copy_(src)
for i, v in enumerate(values):
copy_tensor(v, res[i][size - len(v):, size - len(v):] if left_pad else res[i][:len(v), :len(v)])
return res
def collate_dict(
values,
dim=0,
):
if len(values) <= 0:
return values
ret = {}
keys = values[0].keys()
for key in keys:
ret[key] = torch.stack([v[key] for v in values], dim=dim)
return ret
def str_hash(text: str):
    hash = 0
    for ch in text:
        hash = (hash * 281 ^ ord(ch) * 997) & 0xFFFFFFFF
    return hash
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds, key=None):
"""Context manager which seeds the NumPy PRNG with the specified seed and
restores the state afterward"""
if seed is None:
yield
return
def check_seed(s):
assert type(s) == int or type(s) == np.int32 or type(s) == np.int64
check_seed(seed)
if len(addl_seeds) > 0:
for s in addl_seeds:
check_seed(s)
seed = int(hash((seed, *addl_seeds)) % 1e8)
if key is not None:
seed = int(hash((seed, str_hash(key))) % 1e8)
state = np.random.get_state()
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)
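# Example (illustrative only): draws inside the context are reproducible, and the global
# NumPy RNG state is restored afterwards:
#   with numpy_seed(7):
#       a = np.random.randint(0, 100, 3)
#   with numpy_seed(7):
#       b = np.random.randint(0, 100, 3)
#   assert (a == b).all()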
def batch_by_size(
indices,
batch_size=None,
required_batch_size_multiple=1,
):
"""
Yield mini-batches of indices bucketed by size. Batches may contain
sequences of different lengths.
Args:
indices (List[int]): ordered list of dataset indices
batch_size (int, optional): max number of sentences in each
batch (default: None).
required_batch_size_multiple (int, optional): require batch size to
be less than N or a multiple of N (default: 1).
"""
batch_size = batch_size if batch_size is not None else 1
bsz_mult = required_batch_size_multiple
step = ((batch_size + bsz_mult - 1) // bsz_mult) * bsz_mult
if not isinstance(indices, np.ndarray):
indices = np.fromiter(indices, dtype=np.int64, count=-1)
num_batches = (len(indices) + step - 1) // step
steps = np.arange(num_batches - 1) + 1
steps *= step
batch_indices = np.split(indices, steps)
assert len(batch_indices) == num_batches
# validation or test data size is smaller than a mini-batch size in some downstream tasks.
assert batch_indices[0].shape[0] <= step
return batch_indices
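# Example (illustrative only): 10 ordered indices with batch_size=4 are split into buckets of
# 4, 4 and 2 (the last batch may be smaller):
#   batch_by_size(np.arange(10), batch_size=4)
#   # -> [array([0, 1, 2, 3]), array([4, 5, 6, 7]), array([8, 9])]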
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class Dictionary:
"""A mapping from symbols to consecutive integers"""
def __init__(
self,
*, # begin keyword-only arguments
bos="[CLS]",
pad="[PAD]",
eos="[SEP]",
unk="[UNK]",
extra_special_symbols=None,
):
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
self.symbols = []
self.count = []
self.indices = {}
self.specials = set()
self.specials.add(bos)
self.specials.add(unk)
self.specials.add(pad)
self.specials.add(eos)
def __eq__(self, other):
return self.indices == other.indices
def __getitem__(self, idx):
if idx < len(self.symbols):
return self.symbols[idx]
return self.unk_word
def __len__(self):
"""Returns the number of symbols in the dictionary"""
return len(self.symbols)
def __contains__(self, sym):
return sym in self.indices
def vec_index(self, a):
return np.vectorize(self.index)(a)
def index(self, sym):
"""Returns the index of the specified symbol"""
assert isinstance(sym, str)
if sym in self.indices:
return self.indices[sym]
return self.indices[self.unk_word]
def special_index(self):
return [self.index(x) for x in self.specials]
def add_symbol(self, word, n=1, overwrite=False, is_special=False):
"""Adds a word to the dictionary"""
if is_special:
self.specials.add(word)
if word in self.indices and not overwrite:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(n)
return idx
def bos(self):
"""Helper to get index of beginning-of-sentence symbol"""
return self.index(self.bos_word)
def pad(self):
"""Helper to get index of pad symbol"""
return self.index(self.pad_word)
def eos(self):
"""Helper to get index of end-of-sentence symbol"""
return self.index(self.eos_word)
def unk(self):
"""Helper to get index of unk symbol"""
return self.index(self.unk_word)
@classmethod
def load(cls, f):
"""Loads the dictionary from a text file with the format:
```
<symbol0> <count0>
<symbol1> <count1>
...
```
"""
d = cls()
d.add_from_file(f)
return d
def add_from_file(self, f):
"""
Loads a pre-existing dictionary from a text file and adds its symbols
to this instance.
"""
if isinstance(f, str):
try:
with open(f, "r", encoding="utf-8") as fd:
self.add_from_file(fd)
except FileNotFoundError as fnfe:
raise fnfe
except UnicodeError:
raise Exception(
"Incorrect encoding detected in {}, please "
"rebuild the dataset".format(f)
)
return
lines = f.readlines()
for line_idx, line in enumerate(lines):
try:
splits = line.rstrip().rsplit(" ", 1)
line = splits[0]
field = splits[1] if len(splits) > 1 else str(len(lines) - line_idx)
if field == "#overwrite":
overwrite = True
line, field = line.rsplit(" ", 1)
else:
overwrite = False
count = int(field)
word = line
if word in self and not overwrite:
logger.info(
"Duplicate word found when loading Dictionary: '{}', index is {}.".format(word, self.indices[word])
)
else:
self.add_symbol(word, n=count, overwrite=overwrite)
except ValueError:
raise ValueError(
"Incorrect dictionary format, expected '<token> <cnt> [flags]'"
)
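# Usage sketch (illustrative only, assumes a dict.txt with one "<token> <count>" pair per line):
#   d = Dictionary.load("dict.txt")
#   d.index("hello")   # id of "hello", or d.unk() if the token is not in the vocabulary
#   d[5]               # symbol stored at id 5
#   d.add_symbol("[MASK]", is_special=True)   # how the bert task registers its mask token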
# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from functools import lru_cache
from . import BaseWrapperDataset
class FromNumpyDataset(BaseWrapperDataset):
def __init__(self, dataset):
super().__init__(dataset)
@lru_cache(maxsize=16)
def __getitem__(self, idx):
return torch.from_numpy(self.dataset[idx])
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import logging
import math
import operator
import os
import queue
import time
from threading import Thread
import numpy as np
import torch
from unicore.data import data_utils
logger = logging.getLogger(__name__)
# Object used by _background_consumer to signal the source is exhausted
# to the main thread.
_sentinel = object()
class CountingIterator(object):
"""Wrapper around an iterable that maintains the iteration count.
Args:
iterable (iterable): iterable to wrap
start (int): starting iteration count. Note that this doesn't
actually advance the iterator.
total (int): override the iterator length returned by
``__len__``. This can be used to truncate *iterator*.
Attributes:
n (int): number of elements consumed from this iterator
"""
def __init__(self, iterable, start=None, total=None):
self.iterable = iterable
self.itr = iter(self)
if start is None:
self.n = getattr(iterable, "n", 0)
else:
self.n = start
if total is None:
self.total = self.n + len(iterable)
else:
self.total = total
def __len__(self):
return self.total
def __iter__(self):
for x in self.iterable:
if self.n >= self.total:
raise RuntimeError(
"Mismatch between actual and expected iterable length. "
"This may be caused by resuming training from a checkpoint using "
"a different number of GPUs, in which case you can try the "
"--reset-dataloader option. Alternatively you may have a train or "
"validation set that is smaller than the number of GPUs. If none "
"of these apply, please report this to the unicore developers."
)
self.n += 1
yield x
def __next__(self):
return next(self.itr)
def has_next(self):
"""Whether the iterator has been exhausted."""
return self.n < len(self)
def skip(self, num_to_skip):
"""Fast-forward the iterator by skipping *num_to_skip* elements."""
next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
return self
def take(self, n):
"""
Truncates the iterator to n elements at most.
"""
self.total = min(self.total, n)
# Propagate this change to the underlying iterator
# Only take after what we have already consumed (i.e. after restarting
# from checkpoint mid epoch, we have to subtract self.n which is the
# starting point)
#
# This to maintain the invariant self.total = self.n + len(iterable),
# before calling __next__ or __iter__
propagated_take = max(n - self.n, 0)
if hasattr(self.iterable, "take"):
self.iterable.take(propagated_take)
else:
self.iterable = itertools.islice(self.iterable, propagated_take)
class EpochBatchIterating(object):
def __len__(self) -> int:
raise NotImplementedError
@property
def next_epoch_idx(self):
raise NotImplementedError
def next_epoch_itr(
self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True
):
"""Return a new iterator over the dataset.
Args:
shuffle (bool, optional): shuffle batches before returning the
iterator (default: True).
fix_batches_to_gpus (bool, optional): ensure that batches are always
allocated to the same shards across epochs. Requires
that :attr:`dataset` supports prefetching (default: False).
set_dataset_epoch (bool, optional): update the wrapped Dataset with
the new epoch number (default: True).
"""
raise NotImplementedError
def end_of_epoch(self) -> bool:
"""Returns whether the most recent epoch iterator has been exhausted"""
raise NotImplementedError
@property
def iterations_in_epoch(self) -> int:
"""The number of consumed batches in the current epoch."""
raise NotImplementedError
def state_dict(self):
"""Returns a dictionary containing a whole state of the iterator."""
raise NotImplementedError
def load_state_dict(self, state_dict):
"""Copies the state of the iterator from the given *state_dict*."""
raise NotImplementedError
@property
def first_batch(self):
return "DUMMY"
class EpochBatchIterator(EpochBatchIterating):
"""A multi-epoch iterator over a :class:`torch.utils.data.Dataset`.
Compared to :class:`torch.utils.data.DataLoader`, this iterator:
- can be reused across multiple epochs with the :func:`next_epoch_itr`
method (optionally shuffled between epochs)
- can be serialized/deserialized with the :func:`state_dict` and
:func:`load_state_dict` methods
- supports sharding with the *num_shards* and *shard_id* arguments
Args:
dataset (~torch.utils.data.Dataset): dataset from which to load the data
collate_fn (callable): merges a list of samples to form a mini-batch
batch_sampler (~torch.utils.data.Sampler or a callable): an iterator over batches of
indices, or a callable to create such an iterator (~torch.utils.data.Sampler).
A callable batch_sampler will be called for each epoch to enable per epoch dynamic
batch iterators defined by this callable batch_sampler.
seed (int, optional): seed for random number generator for
reproducibility (default: 1).
num_shards (int, optional): shard the data iterator into N
shards (default: 1).
shard_id (int, optional): which shard of the data iterator to
return (default: 0).
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means the data will be loaded in the main process
(default: 0).
epoch (int, optional): the epoch to start the iterator from
(default: 1).
buffer_size (int, optional): the number of batches to keep ready in the
            queue. Helps speed up data loading. When buffer_size is zero, the
default torch.utils.data.DataLoader preloading is used.
timeout (int, optional): if positive, the timeout value for collecting a batch
from workers. Should always be non-negative (default: ``0``).
disable_shuffling (bool, optional): force disable shuffling
(default: ``False``).
"""
def __init__(
self,
dataset,
collate_fn,
batch_sampler,
seed=1,
num_shards=1,
shard_id=0,
num_workers=0,
epoch=1,
buffer_size=0,
timeout=0,
disable_shuffling=False,
):
assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset
self.collate_fn = collate_fn
self.batch_sampler = batch_sampler
self._frozen_batches = (
tuple(batch_sampler) if not callable(batch_sampler) else None
)
self.seed = seed
self.num_shards = num_shards
self.shard_id = shard_id
self.num_workers = num_workers
# This upper limit here is to prevent people from abusing this feature
# in a shared computing environment.
self.buffer_size = min(buffer_size, 32)
self.timeout = timeout
self.disable_shuffling = disable_shuffling
self.epoch = max(epoch, 1) # we use 1-based indexing for epochs
self.shuffle = not disable_shuffling
self._cur_epoch_itr = None
self._next_epoch_itr = None
self._supports_prefetch = getattr(dataset, "supports_prefetch", False)
@property
def frozen_batches(self):
if self._frozen_batches is None:
self._frozen_batches = tuple(self.batch_sampler(self.dataset, self.epoch))
return self._frozen_batches
@property
def first_batch(self):
if len(self.frozen_batches) == 0:
raise Exception(
"The dataset is empty. This could indicate "
"that all elements in the dataset have been skipped. "
"Try increasing the max number of allowed tokens or using "
"a larger dataset."
)
if getattr(self.dataset, "supports_fetch_outside_dataloader", True):
return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0]])
else:
return "DUMMY"
def __len__(self):
return int(math.ceil(len(self.frozen_batches) / float(self.num_shards)))
@property
def n(self):
return self.iterations_in_epoch
@property
def next_epoch_idx(self):
"""Return the epoch index after *next_epoch_itr* is called."""
if self._next_epoch_itr is not None:
return self.epoch
elif self._cur_epoch_itr is not None and self.end_of_epoch():
return self.epoch + 1
else:
return self.epoch
def next_epoch_itr(
self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True
):
"""Return a new iterator over the dataset.
Args:
shuffle (bool, optional): shuffle batches before returning the
iterator (default: True).
fix_batches_to_gpus (bool, optional): ensure that batches are always
allocated to the same shards across epochs. Requires
that :attr:`dataset` supports prefetching (default: False).
set_dataset_epoch (bool, optional): update the wrapped Dataset with
the new epoch number (default: True).
"""
if self.disable_shuffling:
shuffle = False
self.epoch = self.next_epoch_idx
if set_dataset_epoch and hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(self.epoch)
if self._next_epoch_itr is not None:
self._cur_epoch_itr = self._next_epoch_itr
self._next_epoch_itr = None
else:
if callable(self.batch_sampler):
# reset _frozen_batches to refresh the next epoch
self._frozen_batches = None
self._cur_epoch_itr = self._get_iterator_for_epoch(
self.epoch,
shuffle,
fix_batches_to_gpus=fix_batches_to_gpus,
)
self.shuffle = shuffle
return self._cur_epoch_itr
def end_of_epoch(self) -> bool:
"""Returns whether the most recent epoch iterator has been exhausted"""
return not self._cur_epoch_itr.has_next()
@property
def iterations_in_epoch(self):
"""The number of consumed batches in the current epoch."""
if self._cur_epoch_itr is not None:
return self._cur_epoch_itr.n
elif self._next_epoch_itr is not None:
return self._next_epoch_itr.n
return 0
def state_dict(self):
"""Returns a dictionary containing a whole state of the iterator."""
if self.end_of_epoch():
epoch = self.epoch + 1
iter_in_epoch = 0
else:
epoch = self.epoch
iter_in_epoch = self.iterations_in_epoch
return {
"epoch": epoch,
"iterations_in_epoch": iter_in_epoch,
"shuffle": self.shuffle,
"len": len(self),
}
def load_state_dict(self, state_dict):
"""Copies the state of the iterator from the given *state_dict*."""
self.epoch = state_dict["epoch"]
itr_pos = state_dict.get("iterations_in_epoch", 0)
if itr_pos > 0:
if "len" in state_dict and state_dict["len"] != len(self):
old_itr_pos = itr_pos
itr_pos = int(itr_pos * len(self) / state_dict["len"])
logger.info(
"Iterator size is changed. it is possible due to the change of update_freq/num_gpu. The itr_pos is change from {} to {} for consistency".format(old_itr_pos, itr_pos)
)
# fast-forward epoch iterator
self._next_epoch_itr = self._get_iterator_for_epoch(
self.epoch,
shuffle=state_dict.get("shuffle", True),
offset=itr_pos,
)
if self._next_epoch_itr is None:
raise RuntimeError(
"Cannot resume training due to dataloader mismatch. You can relaunch "
"training with `--reset-dataloader` and it should work."
)
else:
self._next_epoch_itr = None
def _get_iterator_for_epoch(
self, epoch, shuffle, fix_batches_to_gpus=False, offset=0
):
def shuffle_batches(batches, seed):
with data_utils.numpy_seed(seed):
np.random.shuffle(batches)
return batches
if self._supports_prefetch:
batches = self.frozen_batches
if shuffle and not fix_batches_to_gpus:
batches = shuffle_batches(list(batches), self.seed + epoch)
batches = list(
ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
)
self.dataset.prefetch([i for s in batches for i in s])
if shuffle and fix_batches_to_gpus:
batches = shuffle_batches(batches, self.seed + epoch + self.shard_id)
else:
if shuffle:
batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch)
else:
batches = self.frozen_batches
batches = list(
ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
)
if offset > 0 and offset >= len(batches):
return None
if self.num_workers > 0:
os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"
# Create data loader
itr = torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.collate_fn,
batch_sampler=batches[offset:],
num_workers=self.num_workers,
timeout=self.timeout,
)
# Wrap with a BufferedIterator if needed
if self.buffer_size > 0:
itr = BufferedIterator(self.buffer_size, itr)
# Wrap with CountingIterator
itr = CountingIterator(itr, start=offset)
return itr
class GroupedIterator(CountingIterator):
"""Wrapper around an iterable that returns groups (chunks) of items.
Args:
iterable (iterable): iterable to wrap
chunk_size (int): size of each chunk
Attributes:
n (int): number of elements consumed from this iterator
"""
def __init__(self, iterable, chunk_size):
itr = _chunk_iterator(iterable, chunk_size)
super().__init__(
itr,
start=int(math.ceil(getattr(iterable, "n", 0) / float(chunk_size))),
total=int(math.ceil(len(iterable) / float(chunk_size))),
)
self.chunk_size = chunk_size
def _chunk_iterator(itr, chunk_size):
chunk = []
for x in itr:
chunk.append(x)
if len(chunk) == chunk_size:
yield chunk
chunk = []
if len(chunk) > 0:
yield chunk
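# Example (illustrative only): grouping 5 items into chunks of 2 yields [0, 1], [2, 3], [4];
# this is how several batches are consumed per optimizer step when update_freq > 1:
#   list(GroupedIterator(range(5), chunk_size=2))
#   # -> [[0, 1], [2, 3], [4]]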
class ShardedIterator(CountingIterator):
"""A sharded wrapper around an iterable, padded to length.
Args:
iterable (iterable): iterable to wrap
num_shards (int): number of shards to split the iterable into
shard_id (int): which shard to iterator over
fill_value (Any, optional): padding value when the iterable doesn't
evenly divide *num_shards* (default: None).
Attributes:
n (int): number of elements consumed from this iterator
"""
def __init__(self, iterable, num_shards, shard_id, fill_value=None):
if shard_id < 0 or shard_id >= num_shards:
raise ValueError("shard_id must be between 0 and num_shards")
sharded_len = int(math.ceil(len(iterable) / float(num_shards)))
itr = map(
operator.itemgetter(1),
itertools.zip_longest(
range(sharded_len),
itertools.islice(iterable, shard_id, len(iterable), num_shards),
fillvalue=fill_value,
),
)
super().__init__(
itr,
start=int(math.ceil(getattr(iterable, "n", 0) / float(num_shards))),
total=sharded_len,
)
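# Example (illustrative only): sharding 5 batches across 2 shards pads the short shard with
# fill_value so every shard yields the same number of steps:
#   list(ShardedIterator(range(5), num_shards=2, shard_id=0, fill_value=[]))  # -> [0, 2, 4]
#   list(ShardedIterator(range(5), num_shards=2, shard_id=1, fill_value=[]))  # -> [1, 3, []]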
class BackgroundConsumer(Thread):
def __init__(self, queue, source, max_len):
Thread.__init__(self)
self._queue = queue
self._source = source
self._max_len = max_len
self.count = 0
def run(self):
try:
for item in self._source:
self._queue.put(item)
# Stop if we reached the maximum length
self.count += 1
if self._max_len is not None and self.count >= self._max_len:
break
# Signal the consumer we are done.
self._queue.put(_sentinel)
except Exception as e:
self._queue.put(e)
class BufferedIterator(object):
def __init__(self, size, iterable):
self._queue = queue.Queue(size)
self._iterable = iterable
self._consumer = None
self.start_time = time.time()
self.warning_time = None
self.total = len(iterable)
def _create_consumer(self):
self._consumer = BackgroundConsumer(
self._queue,
self._iterable,
self.total,
)
self._consumer.daemon = True
self._consumer.start()
def __iter__(self):
return self
def __len__(self):
return self.total
def take(self, n):
self.total = min(self.total, n)
# Propagate this change to the underlying iterator
if hasattr(self._iterable, "take"):
self._iterable.take(n)
def __next__(self):
# Create consumer if not created yet
if self._consumer is None:
self._create_consumer()
# Notify the user if there is a data loading bottleneck
if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)):
if time.time() - self.start_time > 5 * 60:
if (
self.warning_time is None
or time.time() - self.warning_time > 15 * 60
):
logger.debug(
"Data loading buffer is empty or nearly empty. This may "
"indicate a data loading bottleneck, and increasing the "
"number of workers (--num-workers) may help."
)
self.warning_time = time.time()
# Get next example
item = self._queue.get(True)
if isinstance(item, Exception):
raise item
if item is _sentinel:
raise StopIteration()
return item
# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import lmdb
import os
import pickle
import torch
import numpy as np
import collections
from functools import lru_cache
from . import data_utils
import logging
logger = logging.getLogger(__name__)
class LMDBDataset:
def __init__(self, db_path):
self.db_path = db_path
assert os.path.isfile(self.db_path), "{} not found".format(
self.db_path
)
env = self.connect_db(self.db_path)
with env.begin() as txn:
self._keys = list(txn.cursor().iternext(values=False))
def connect_db(self, lmdb_path, save_to_self=False):
env = lmdb.open(
lmdb_path,
subdir=False,
readonly=True,
lock=False,
readahead=False,
meminit=False,
max_readers=256,
)
if not save_to_self:
return env
else:
self.env = env
def __len__(self):
return len(self._keys)
@lru_cache(maxsize=16)
def __getitem__(self, idx):
if not hasattr(self, 'env'):
self.connect_db(self.db_path, save_to_self=True)
datapoint_pickled = self.env.begin().get(self._keys[idx])
data = pickle.loads(datapoint_pickled)
return data
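# Usage sketch (illustrative only): reading records written by the preprocess.py example, where
# each value is a pickled string keyed by its line index:
#   dataset = LMDBDataset("./train.lmdb")
#   len(dataset)   # number of stored lines
#   dataset[0]     # first stored line, unpickled
# In the bert task this dataset is wrapped by BertTokenizeDataset and MaskTokensDataset.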