Commit 3b355d3f authored by yuguo960516

gpt2

parent fd158e88
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT Style dataset."""
import numpy as np
import oneflow as flow
from libai.data.structures import DistTensorData, Instance
from ..data_utils import create_masked_lm_predictions, get_samples_mapping
class BertDataset(flow.utils.data.Dataset):
"""Dataset containing sentence pairs for BERT training.
Each index corresponds to a randomly generated sentence pair.
Args:
name: Name of dataset for clarification.
tokenizer: Tokenizer to use.
data_prefix: Path to the training dataset.
indexed_dataset: Indexed dataset to use.
max_seq_length: Maximum length of the sequence. All values are padded to
this length. Defaults to 512.
mask_lm_prob: Probability to mask tokens. Defaults to 0.15.
short_seq_prob: Probability of producing a short sequence. Defaults to 0.0.
max_predictions_per_seq: Maximum number of mask tokens in each sentence. Defaults to None.
seed: Seed for random number generator for reproducibility. Defaults to 1234.
binary_head: Specifies whether the underlying dataset
generates a pair of blocks along with a sentence_target or not.
Setting it to True assumes that the underlying dataset generates a
label for the pair of sentences which is surfaced as
sentence_target. Defaults to True.
"""
def __init__(
self,
name,
tokenizer,
indexed_dataset,
data_prefix,
max_num_samples,
mask_lm_prob,
max_seq_length,
short_seq_prob=0.0,
seed=1234,
binary_head=True,
masking_style="bert",
):
# Params to store.
self.name = name
self.seed = seed
self.masked_lm_prob = mask_lm_prob
self.max_seq_length = max_seq_length
self.binary_head = binary_head
self.masking_style = masking_style
# Dataset.
self.indexed_dataset = indexed_dataset
# Build the samples mapping.
self.samples_mapping = get_samples_mapping(
self.indexed_dataset,
data_prefix,
None,
max_num_samples,
self.max_seq_length - 3, # account for added tokens
short_seq_prob,
self.seed,
self.name,
self.binary_head,
)
# Vocab stuff.
self.tokenizer = tokenizer
self.vocab_id_list = list(tokenizer.get_vocab().values())
self.vocab_id_to_token_dict = {v: k for k, v in tokenizer.get_vocab().items()}
self.cls_id = tokenizer.cls_token_id
self.sep_id = tokenizer.sep_token_id
self.mask_id = tokenizer.mask_token_id
self.pad_id = tokenizer.pad_token_id
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_idx, end_idx, seq_length = self.samples_mapping[idx]
sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
# Note that this rng state should be numpy and not python since
# python randint is inclusive whereas the numpy one is exclusive.
# We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
np_rng = np.random.RandomState(seed=((self.seed + idx) % 2 ** 32))
return build_training_sample(
self.tokenizer,
sample,
seq_length,
self.max_seq_length, # needed for padding
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id,
self.sep_id,
self.mask_id,
self.pad_id,
self.masked_lm_prob,
np_rng,
self.binary_head,
masking_style=self.masking_style,
)
def build_training_sample(
tokenizer,
sample,
target_seq_length,
max_seq_length,
vocab_id_list,
vocab_id_to_token_dict,
cls_id,
sep_id,
mask_id,
pad_id,
masked_lm_prob,
np_rng,
binary_head,
masking_style="bert",
):
"""Build training sample.
Arguments:
sample: A list of sentences in which each sentence is a list token ids.
target_seq_length: Desired sequence length.
max_seq_length: Maximum length of the sequence. All values are padded to
this length.
vocab_id_list: List of vocabulary ids. Used to pick a random id.
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
cls_id: Start of example id.
sep_id: Separator id.
mask_id: Mask token id.
pad_id: Padding token id.
masked_lm_prob: Probability to mask tokens.
        np_rng: Random number generator. Note that this rng state should be
            numpy and not python since python randint is inclusive for
            the upper bound whereas the numpy one is exclusive.
"""
if binary_head:
# We assume that we have at least two sentences in the sample
assert len(sample) > 1
assert target_seq_length <= max_seq_length
# Divide sample into two segments (A and B).
if binary_head:
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
else:
tokens_a = []
for j in range(len(sample)):
tokens_a.extend(sample[j])
tokens_b = []
is_next_random = False
# Truncate to `target_sequence_length`.
max_num_tokens = target_seq_length
truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b), max_num_tokens, np_rng)
    # Build tokens and tokentypes.
tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id)
# Masking.
max_predictions_per_seq = masked_lm_prob * max_num_tokens
(tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
tokenizer,
tokens,
vocab_id_list,
vocab_id_to_token_dict,
masked_lm_prob,
cls_id,
sep_id,
mask_id,
max_predictions_per_seq,
np_rng,
masking_style=masking_style,
)
# Padding.
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = pad_and_convert_to_numpy(
tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length
)
train_sample = Instance(
input_ids=DistTensorData(flow.tensor(tokens_np)),
attention_mask=DistTensorData(flow.tensor(padding_mask_np)),
tokentype_ids=DistTensorData(flow.tensor(tokentypes_np)),
ns_labels=DistTensorData(
flow.tensor(int(is_next_random), dtype=flow.long), placement_idx=-1
),
lm_labels=DistTensorData(flow.tensor(labels_np), placement_idx=-1),
loss_mask=DistTensorData(flow.tensor(loss_mask_np), placement_idx=-1),
)
return train_sample
def pad_and_convert_to_numpy(
tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length
):
"""Pad sequences and convert them to numpy."""
# Some checks.
num_tokens = len(tokens)
padding_length = max_seq_length - num_tokens
assert padding_length >= 0
assert len(tokentypes) == num_tokens
assert len(masked_positions) == len(masked_labels)
# Tokens and token types.
filler = [pad_id] * padding_length
tokens_np = np.array(tokens + filler, dtype=np.int64)
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
# Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=bool)
    # Labels and loss mask.
labels = [-1] * max_seq_length
loss_mask = [0] * max_seq_length
for i in range(len(masked_positions)):
assert masked_positions[i] < num_tokens
labels[masked_positions[i]] = masked_labels[i]
loss_mask[masked_positions[i]] = 1
labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=bool)
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
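# Illustrative sketch (editor's note, hypothetical values): with max_seq_length=8,
# tokens=[101, 7, 8, 102], tokentypes=[0, 0, 0, 0], masked_positions=[2],
# masked_labels=[9] and pad_id=0, the function returns
#   tokens_np       -> [101, 7, 8, 102, 0, 0, 0, 0]
#   tokentypes_np   -> [0, 0, 0, 0, 0, 0, 0, 0]
#   labels_np       -> [-1, -1, 9, -1, -1, -1, -1, -1]
#   padding_mask_np -> [1, 1, 1, 1, 0, 0, 0, 0]
#   loss_mask_np    -> [0, 0, 1, 0, 0, 0, 0, 0]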
def get_a_and_b_segments(sample, np_rng):
"""Divide sample into a and b segments."""
# Number of sentences in the sample.
n_sentences = len(sample)
# Make sure we always have two sentences.
assert n_sentences > 1, "make sure each sample has at least two sentences."
# First part:
# `a_end` is how many sentences go into the `A`.
a_end = 1
if n_sentences >= 3:
        # Note that randint in numpy is exclusive on the upper bound.
a_end = np_rng.randint(1, n_sentences)
tokens_a = []
for j in range(a_end):
tokens_a.extend(sample[j])
# Second part:
tokens_b = []
for j in range(a_end, n_sentences):
tokens_b.extend(sample[j])
# Random next:
is_next_random = False
if np_rng.random() < 0.5:
is_next_random = True
tokens_a, tokens_b = tokens_b, tokens_a
return tokens_a, tokens_b, is_next_random
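# Illustrative sketch (editor's note, hypothetical sample): for
# sample = [[1, 2], [3, 4], [5, 6]] the split point `a_end` is drawn from {1, 2};
# with a_end=2, tokens_a=[1, 2, 3, 4] and tokens_b=[5, 6]. With probability 0.5
# the two segments are swapped and is_next_random=True, which is surfaced as
# `ns_labels` in the training sample.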
def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
"""Truncates a pair of sequences to a maximum sequence length."""
assert len_a > 0
if len_a + len_b <= max_num_tokens:
return False
while len_a + len_b > max_num_tokens:
if len_a > len_b:
len_a -= 1
tokens = tokens_a
else:
len_b -= 1
tokens = tokens_b
if np_rng.random() < 0.5:
del tokens[0]
else:
tokens.pop()
return True
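# Illustrative sketch (editor's note, hypothetical values): truncation removes one
# token at a time from the longer segment, dropping it from the front or the back
# with equal probability, e.g.
#   tokens_a = [1, 2, 3, 4]; tokens_b = [5, 6]
#   truncate_segments(tokens_a, tokens_b, 4, 2, 5, np_rng)
# leaves len(tokens_a) + len(tokens_b) == 5 and returns True.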
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
"""Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
tokens = []
tokentypes = []
# [CLS].
tokens.append(cls_id)
tokentypes.append(0)
# Segment A.
for token in tokens_a:
tokens.append(token)
tokentypes.append(0)
# [SEP].
tokens.append(sep_id)
tokentypes.append(0)
# Segment B.
for token in tokens_b:
tokens.append(token)
tokentypes.append(1)
if tokens_b:
# [SEP].
tokens.append(sep_id)
tokentypes.append(1)
return tokens, tokentypes
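# Illustrative sketch (editor's note, hypothetical ids): with cls_id=101, sep_id=102,
# tokens_a=[7, 8] and tokens_b=[9]:
#   tokens     -> [101, 7, 8, 102, 9, 102]
#   tokentypes -> [  0, 0, 0,   0, 1,   1]
# i.e. the usual [CLS] A [SEP] B [SEP] layout with segment ids 0/1.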
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional
import oneflow as flow
from flowvision import datasets
from libai.data.structures import DistTensorData, Instance
class CIFAR10Dataset(datasets.CIFAR10):
r"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset in LiBai.
Args:
root (string): Root directory of dataset where directory
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If the dataset is already downloaded, it will not be
downloaded again.
"""
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
download: bool = False,
**kwargs
):
super(CIFAR10Dataset, self).__init__(
root=root, train=train, transform=transform, download=download, **kwargs
)
def __getitem__(self, index: int):
img, target = super().__getitem__(index)
data_sample = Instance(
images=DistTensorData(img, placement_idx=0),
labels=DistTensorData(flow.tensor(target, dtype=flow.long), placement_idx=-1),
)
return data_sample
class CIFAR100Dataset(datasets.CIFAR100):
r"""`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset in LiBai.
Args:
root (string): Root directory of dataset where directory
            ``cifar-100-python`` exists or will be saved to if download is set to True.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If the dataset is already downloaded, it will not be
downloaded again.
dataset_name (str, optional): Name for the dataset as an identifier. E.g, ``cifar100``
"""
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
download: bool = False,
**kwargs
):
super(CIFAR100Dataset, self).__init__(
root=root, train=train, transform=transform, download=download, **kwargs
)
def __getitem__(self, index: int):
img, target = super().__getitem__(index)
data_sample = Instance(
images=DistTensorData(img, placement_idx=0),
labels=DistTensorData(flow.tensor(target, dtype=flow.long), placement_idx=-1),
)
return data_sample
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT style dataset."""
import logging
import os
import time
import numpy as np
import oneflow as flow
from libai.data.structures import DistTensorData, Instance
from libai.utils import distributed as dist
logger = logging.getLogger(__name__)
class GPT2Dataset(flow.utils.data.Dataset):
def __init__(
self,
name,
tokenizer,
data_prefix,
indexed_dataset,
max_num_samples,
max_seq_length,
seed=1234,
):
self.name = name
self.tokenizer = tokenizer
self.indexed_dataset = indexed_dataset
documents = np.arange(start=0, stop=indexed_dataset.sizes.shape[0], step=1, dtype=np.int32)
# Build index mappings.
self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
self.name,
data_prefix,
documents,
self.indexed_dataset.sizes,
max_num_samples,
max_seq_length,
seed,
)
def __len__(self):
# -1 is due to data structure used to retrieve the index:
# sample i --> [sample_idx[i], sample_idx[i+1])
return self.sample_idx.shape[0] - 1
def __getitem__(self, idx):
# Get the shuffled index.
idx = self.shuffle_idx[idx]
# Start and end documents and offsets.
doc_index_f = self.sample_idx[idx][0]
doc_index_l = self.sample_idx[idx + 1][0]
offset_f = self.sample_idx[idx][1]
offset_l = self.sample_idx[idx + 1][1]
# If we are within the same document, just extract the chunk.
if doc_index_f == doc_index_l:
sample = self.indexed_dataset.get(
self.doc_idx[doc_index_f], offset=offset_f, length=offset_l - offset_f + 1
)
else:
# Otherwise, get the rest of the initial document.
sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)]
# Loop over all in between documents and add the entire document.
for i in range(doc_index_f + 1, doc_index_l):
sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
# And finally add the relevant portion of last document.
sample_list.append(
self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1)
)
sample = np.concatenate(sample_list)
input_ids = flow.tensor(np.array(sample[:-1], dtype=np.int64))
lm_labels = flow.tensor(np.array(sample[1:], dtype=np.int64))
sample = Instance(
input_ids=DistTensorData(input_ids),
labels=DistTensorData(lm_labels, placement_idx=-1),
)
return sample
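# Illustrative sketch (editor's note): each retrieved chunk holds max_seq_length + 1
# tokens, which __getitem__ shifts by one position to form the causal-LM pair,
# e.g. for sample = [5, 6, 7, 8]:
#   input_ids -> [5, 6, 7]
#   labels    -> [6, 7, 8]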
def _build_index_mappings(name, data_prefix, documents, sizes, num_samples, seq_length, seed):
"""Build doc-idx, sample-idx, and shuffle-idx.
doc-idx: is an array (ordered) of documents to be used in training.
sample-idx: is the start document index and document offset for each
training sample.
shuffle-idx: maps the sample index into a random index into sample-idx.
"""
# Number of tokens in each epoch and number of required epochs.
tokens_per_epoch = _num_tokens(documents, sizes)
num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
# rng state
np_rng = np.random.RandomState(seed=seed)
# Filename of the index mappings.
_filename = data_prefix
_filename += "_{}_indexmap".format(name)
_filename += "_{}ns".format(num_samples)
_filename += "_{}sl".format(seq_length)
_filename += "_{}s".format(seed)
doc_idx_filename = _filename + "_doc_idx.npy"
sample_idx_filename = _filename + "_sample_idx.npy"
shuffle_idx_filename = _filename + "_shuffle_idx.npy"
# Build the indexed mapping if not exist.
    # NOTE: use `get_local_rank() == 0` to ensure the index mappings are built once on each node.
if flow.env.get_local_rank() == 0:
if (
(not os.path.isfile(doc_idx_filename))
or (not os.path.isfile(sample_idx_filename))
or (not os.path.isfile(shuffle_idx_filename))
):
logger.info(
" > WARNING: could not find index map files, building " "the indices on rank 0 ..."
)
            # For the last epoch, decide whether to include the entire epoch
            # in the global shuffle or not.
            # If we need only one epoch, then separating the last epoch does
            # not mean anything.
if num_epochs == 1:
separate_last_epoch = False
logger.info(" > only one epoch required, setting " "separate_last_epoch to False")
else:
# Get the number of samples for the last epoch
num_samples_from_epochs_minus_one = (
(num_epochs - 1) * tokens_per_epoch - 1
) // seq_length
last_epoch_num_samples = num_samples - num_samples_from_epochs_minus_one
assert (
last_epoch_num_samples >= 0
), "last epoch number of samples should be non-negative."
num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
assert last_epoch_num_samples < (
num_samples_per_epoch + 1
), "last epoch number of samples exceeded max value."
# If we have less than 80% of the samples for the last epoch,
# separate out the epoch and treat it differently.
# Note: the 80% number is just based on common sense and can
# be adjusted if needed.
separate_last_epoch = last_epoch_num_samples < int(0.80 * num_samples_per_epoch)
if separate_last_epoch:
string = (
" > last epoch number of samples ({}) is smaller "
"than 80% of number of samples per epoch ({}), "
"setting separate_last_epoch to True"
)
else:
string = (
" > last epoch number of samples ({}) is larger "
"than 80% of number of samples per epoch ({}), "
"setting separate_last_epoch to False"
)
logger.info(string.format(last_epoch_num_samples, num_samples_per_epoch))
# doc-idx.
logger.info("start to build and save doc-idx mapping ...")
start_time = time.time()
doc_idx = _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch)
np.save(doc_idx_filename, doc_idx, allow_pickle=True)
logger.info(
" > elapsed time to build and save doc-idx mapping "
"(seconds): {:4f}".format(time.time() - start_time)
)
# sample-idx.
logger.info("start to build and save sample-idx mapping ...")
start_time = time.time()
# Use C++ implementation for speed.
# First compile and then import.
from libai.data.data_utils import helpers
assert doc_idx.dtype == np.int32
assert sizes.dtype == np.int32
sample_idx = helpers.build_sample_idx(
sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
)
# sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
# num_epochs, tokens_per_epoch)
np.save(sample_idx_filename, sample_idx, allow_pickle=True)
logger.info(
" > elapsed time to build and save sample-idx mapping "
"(seconds): {:4f}".format(time.time() - start_time)
)
# shuffle-idx.
start_time = time.time()
# -1 is due to data structure used to retrieve the index:
# sample i --> [sample_idx[i], sample_idx[i+1])
if separate_last_epoch:
num_samples_ = num_samples_from_epochs_minus_one
else:
num_samples_ = sample_idx.shape[0] - 1
shuffle_idx = _build_shuffle_idx(num_samples_, sample_idx.shape[0] - 1, np_rng)
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
logger.info(
" > elapsed time to build and save shuffle-idx mapping"
" (seconds): {:4f}".format(time.time() - start_time)
)
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
dist.synchronize()
# Load mappings.
start_time = time.time()
logger.info(" > loading doc-idx mapping from {}".format(doc_idx_filename))
doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode="r")
logger.info(" > loading sample-idx mapping from {}".format(sample_idx_filename))
sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode="r")
logger.info(" > loading shuffle-idx mapping from {}".format(shuffle_idx_filename))
shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r")
logger.info(" loaded indexed file in {:3.3f} seconds".format(time.time() - start_time))
logger.info(" total number of samples: {}".format(sample_idx.shape[0]))
logger.info(" total number of epochs: {}".format(num_epochs))
return doc_idx, sample_idx, shuffle_idx
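# Illustrative sketch (editor's note): sample_idx stores (document index, offset)
# pairs, so training sample i spans
#   doc_idx[sample_idx[i][0]] starting at offset sample_idx[i][1]
# up to (and including) doc_idx[sample_idx[i + 1][0]] at offset sample_idx[i + 1][1];
# shuffle_idx then permutes which i is served for a given dataset index.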
def _num_tokens(documents, sizes):
"""Total number of tokens in the dataset."""
return np.sum(sizes[documents])
def _num_epochs(tokens_per_epoch, seq_length, num_samples):
"""Based on number of samples and sequence length, calculate how many
epochs will be needed."""
num_epochs = 0
total_tokens = 0
while True:
num_epochs += 1
total_tokens += tokens_per_epoch
# -1 is because we need to retrieve seq_length + 1 token each time
# but the last token will overlap with the first token of the next
# sample except for the last sample.
if ((total_tokens - 1) // seq_length) >= num_samples:
return num_epochs
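# Illustrative sketch (editor's note, hypothetical numbers): with tokens_per_epoch=1000,
# seq_length=128 and num_samples=20, the loop returns 3, since
# (3 * 1000 - 1) // 128 = 23 >= 20 while (2 * 1000 - 1) // 128 = 15 < 20.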
def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch):
"""Build an array with length = number-of-epochs * number-of-documents.
Each index is mapped to a corresponding document."""
if not separate_last_epoch or num_epochs == 1:
doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1]
doc_idx[:] = documents
doc_idx = doc_idx.reshape(-1)
doc_idx = doc_idx.astype(np.int32)
np_rng.shuffle(doc_idx)
return doc_idx
doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False)
doc_idx_last = _build_doc_idx(documents, 1, np_rng, False)
return np.concatenate((doc_idx_first, doc_idx_last))
def _build_shuffle_idx(num_samples, total_size, np_rng):
"""Build the range [0, size) and shuffle."""
logger.info(
" > building shuffle index with split [0, {}) and [{}, {}) "
"...".format(num_samples, num_samples, total_size)
)
dtype_ = np.uint32
if total_size >= (np.iinfo(np.uint32).max - 1):
dtype_ = np.int64
shuffle_idx_first = np.arange(start=0, stop=num_samples, step=1, dtype=dtype_)
np_rng.shuffle(shuffle_idx_first)
if num_samples == total_size:
return shuffle_idx_first
shuffle_idx_last = np.arange(start=num_samples, stop=total_size, step=1, dtype=dtype_)
np_rng.shuffle(shuffle_idx_last)
return np.concatenate((shuffle_idx_first, shuffle_idx_last))
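# Illustrative sketch (editor's note, hypothetical numbers): with num_samples=4 and
# total_size=6 the ranges [0, 4) and [4, 6) are shuffled independently and
# concatenated, e.g.
#   [2, 0, 3, 1] + [5, 4] -> [2, 0, 3, 1, 5, 4]
# so samples from the separated last epoch stay at the tail of the mapping.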
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable, Optional
import oneflow as flow
from flowvision import datasets
from libai.data.structures import DistTensorData, Instance
class ImageNetDataset(datasets.ImageFolder):
r"""`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset in LiBai.
Args:
root (string): Root directory of the ImageNet Dataset.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
"""
def __init__(
self, root: str, train: bool = True, transform: Optional[Callable] = None, **kwargs
):
prefix = "train" if train else "val"
root = os.path.join(root, prefix)
super(ImageNetDataset, self).__init__(root=root, transform=transform, **kwargs)
def __getitem__(self, index: int):
sample, target = super().__getitem__(index)
data_sample = Instance(
images=DistTensorData(sample, placement_idx=0),
labels=DistTensorData(flow.tensor(target, dtype=flow.long), placement_idx=-1),
)
return data_sample
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional
import oneflow as flow
from flowvision import datasets
from libai.data.structures import DistTensorData, Instance
class MNISTDataset(datasets.MNIST):
r"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset in LiBai.
Args:
root (string): Root directory of dataset where ``MNIST/processed/training.pt``
and ``MNIST/processed/test.pt`` exist.
train (bool, optional): If True, creates dataset from ``training.pt``,
otherwise from ``test.pt``.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If the dataset is already downloaded, it will not be
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
dataset_name (str, optional): Name for the dataset as an identifier. E.g, ``mnist``
"""
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
download: bool = False,
**kwargs
):
super(MNISTDataset, self).__init__(
root=root, train=train, transform=transform, download=download, **kwargs
)
def __getitem__(self, index: int):
img, target = super().__getitem__(index)
data_sample = Instance(
images=DistTensorData(img, placement_idx=0),
labels=DistTensorData(flow.tensor(target, dtype=flow.long), placement_idx=-1),
)
return data_sample
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Roberta Style dataset."""
import numpy as np
import oneflow as flow
from libai.data.structures import DistTensorData, Instance
from ..data_utils import create_masked_lm_predictions, get_samples_mapping
from .bert_dataset import pad_and_convert_to_numpy
class RobertaDataset(flow.utils.data.Dataset):
"""Dataset containing sentence for RoBERTa training.
Each index corresponds to a randomly selected sentence.
Args:
name: Name of dataset for clarification.
tokenizer: Tokenizer to use.
data_prefix: Path to the training dataset.
indexed_dataset: Indexed dataset to use.
max_seq_length: Maximum length of the sequence. All values are padded to
this length. Defaults to 512.
mask_lm_prob: Probability to mask tokens. Defaults to 0.15.
short_seq_prob: Probability of producing a short sequence. Defaults to 0.0.
max_predictions_per_seq: Maximum number of mask tokens in each sentence. Defaults to None.
seed: Seed for random number generator for reproducibility. Defaults to 1234.
"""
def __init__(
self,
name,
tokenizer,
indexed_dataset,
data_prefix,
max_num_samples,
mask_lm_prob,
max_seq_length,
short_seq_prob=0.0,
seed=1234,
masking_style="bert",
):
super().__init__()
# Params to store.
self.name = name
self.seed = seed
self.masked_lm_prob = mask_lm_prob
self.max_seq_length = max_seq_length
self.masking_style = masking_style
# Dataset.
self.indexed_dataset = indexed_dataset
# Build the samples mapping.
self.samples_mapping = get_samples_mapping(
self.indexed_dataset,
data_prefix,
None,
max_num_samples,
self.max_seq_length - 2, # account for added tokens
short_seq_prob,
self.seed,
self.name,
binary_head=False,
)
# Vocab stuff.
self.tokenizer = tokenizer
self.vocab_id_list = list(tokenizer.get_vocab().values())
self.vocab_id_to_token_dict = {v: k for k, v in tokenizer.get_vocab().items()}
self.cls_id = tokenizer.cls_token_id
self.sep_id = tokenizer.sep_token_id
self.mask_id = tokenizer.mask_token_id
self.pad_id = tokenizer.pad_token_id
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_idx, end_idx, seq_length = self.samples_mapping[idx]
sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
# Note that this rng state should be numpy and not python since
# python randint is inclusive whereas the numpy one is exclusive.
# We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
np_rng = np.random.RandomState(seed=((self.seed + idx) % 2 ** 32))
return build_training_sample(
self.tokenizer,
sample,
seq_length,
self.max_seq_length, # needed for padding
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id,
self.sep_id,
self.mask_id,
self.pad_id,
self.masked_lm_prob,
np_rng,
masking_style=self.masking_style,
)
def build_training_sample(
tokenizer,
sample,
target_seq_length,
max_seq_length,
vocab_id_list,
vocab_id_to_token_dict,
cls_id,
sep_id,
mask_id,
pad_id,
masked_lm_prob,
np_rng,
masking_style="bert",
):
"""Build training sample.
Arguments:
sample: A list of sentences in which each sentence is a list token ids.
target_seq_length: Desired sequence length.
max_seq_length: Maximum length of the sequence. All values are padded to
this length.
vocab_id_list: List of vocabulary ids. Used to pick a random id.
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
cls_id: Start of example id.
sep_id: Separator id.
mask_id: Mask token id.
pad_id: Padding token id.
masked_lm_prob: Probability to mask tokens.
        np_rng: Random number generator. Note that this rng state should be
            numpy and not python since python randint is inclusive for
            the upper bound whereas the numpy one is exclusive.
"""
assert target_seq_length <= max_seq_length
tokens = []
for j in range(len(sample)):
tokens.extend(sample[j])
max_num_tokens = target_seq_length
truncate_segments(tokens, len(tokens), max_num_tokens, np_rng)
# create tokens and tokentypes
tokens, tokentypes = create_tokens_and_tokentypes(tokens, cls_id, sep_id)
# Masking
max_predictions_per_seq = masked_lm_prob * max_num_tokens
(tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
tokenizer,
tokens,
vocab_id_list,
vocab_id_to_token_dict,
masked_lm_prob,
cls_id,
sep_id,
mask_id,
max_predictions_per_seq,
np_rng,
masking_style=masking_style,
)
# Padding.
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = pad_and_convert_to_numpy(
tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length
)
train_sample = Instance(
input_ids=DistTensorData(flow.tensor(tokens_np)),
attention_mask=DistTensorData(flow.tensor(padding_mask_np)),
tokentype_ids=DistTensorData(flow.tensor(tokentypes_np)),
lm_labels=DistTensorData(flow.tensor(labels_np), placement_idx=-1),
loss_mask=DistTensorData(flow.tensor(loss_mask_np), placement_idx=-1),
)
return train_sample
def truncate_segments(tokens, len_tokens, max_num_tokens, np_rng):
"""Truncates a sequences to a maximum sequence length."""
assert len_tokens > 0
if len_tokens <= max_num_tokens:
return False
while len_tokens > max_num_tokens:
if np_rng.random() < 0.5:
del tokens[0]
else:
tokens.pop()
len_tokens -= 1
return True
def create_tokens_and_tokentypes(tokens, cls_id, sep_id):
"""Add [CLS] and [SEP] and build tokentypes."""
# [CLS].
tokens.insert(0, cls_id)
    # [SEP].
tokens.append(sep_id)
tokentypes = [0] * len(tokens)
return tokens, tokentypes
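# Illustrative sketch (editor's note, hypothetical ids): with cls_id=0, sep_id=2 and
# tokens=[7, 8, 9], the result is
#   tokens     -> [0, 7, 8, 9, 2]
#   tokentypes -> [0, 0, 0, 0, 0]
# i.e. a single segment wrapped in [CLS] ... [SEP] with all token type ids set to 0.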
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""T5 Style dataset."""
import collections
import numpy as np
import oneflow as flow
from libai.data.structures import DistTensorData, Instance
from ..data_utils import create_masked_lm_predictions, get_samples_mapping
class T5Dataset(flow.utils.data.Dataset):
"""
Dataset containing sentences for T5 training.
Args:
name: Name of dataset.
tokenizer: Tokenizer to use.
data_prefix (str): Path to the training dataset.
indexed_dataset: Indexed dataset to use.
max_seq_length (int, optional): Maximum length of the sequence passing into encoder.
All values are padded to this length. Defaults to 512.
max_seq_length_dec (int, optional): Maximum length of the sequence passing into decoder.
All values are padded to this length. Defaults to 128.
mask_lm_prob (float, optional): Probability to mask tokens. Defaults to 0.15.
max_preds_per_seq (int, optional): Maximum number of masked tokens in each sentence.
Defaults to None.
short_seq_prob (float, optional):
Probability of producing a short sequence. Defaults to 0.0.
seed (int, optional):
Seed for random number generator for reproducibility. Defaults to 1234.
"""
def __init__(
self,
name,
tokenizer,
indexed_dataset,
data_prefix,
max_num_samples,
masked_lm_prob,
max_seq_length,
max_seq_length_dec,
short_seq_prob,
seed,
):
# Params to store.
self.name = name
self.seed = seed
self.masked_lm_prob = masked_lm_prob
self.max_seq_length = max_seq_length
self.max_seq_length_dec = max_seq_length_dec
# Dataset.
self.indexed_dataset = indexed_dataset
# Build the samples mapping.
self.samples_mapping = get_samples_mapping(
self.indexed_dataset,
data_prefix,
None,
max_num_samples,
self.max_seq_length - 2, # account for added tokens
short_seq_prob,
self.seed,
self.name,
False,
)
# Vocab stuff.
self.tokenizer = tokenizer
tokenizer.add_tokens(
[tokenizer._bos_token, tokenizer._eos_token, *tokenizer._additional_special_tokens]
)
vocab = tokenizer.get_vocab()
inv_vocab = {v: k for k, v in vocab.items()}
self.vocab_id_list = list(inv_vocab.keys())
self.vocab_id_to_token_dict = inv_vocab
self.cls_id = vocab[tokenizer._cls_token]
self.sep_id = vocab[tokenizer._sep_token]
self.mask_id = vocab[tokenizer._mask_token]
self.pad_id = vocab[tokenizer._pad_token]
self.bos_id = vocab[tokenizer._bos_token]
self.eos_id = vocab[tokenizer._eos_token]
self.sentinel_tokens = [vocab[x] for x in tokenizer._additional_special_tokens]
assert len(self.sentinel_tokens) > 0
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_index, end_index, seq_length = self.samples_mapping[idx]
sample = []
for index in range(start_index, end_index):
sample.append(self.indexed_dataset[index])
# Note that this rng state should be numpy and not python since
# python randint is inclusive whereas the numpy one is exclusive.
np_rng = np.random.RandomState(seed=(self.seed + idx))
return build_training_sample(
self.tokenizer,
sample,
seq_length,
self.max_seq_length, # needed for padding
self.max_seq_length_dec,
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id,
self.sep_id,
self.mask_id,
self.pad_id,
self.masked_lm_prob,
np_rng,
self.bos_id,
self.eos_id,
self.sentinel_tokens,
)
def build_training_sample(
tokenizer,
sample,
target_seq_length,
max_seq_length,
max_seq_length_dec,
vocab_id_list,
vocab_id_to_token_dict,
cls_id,
sep_id,
mask_id,
pad_id,
masked_lm_prob,
np_rng,
bos_id=None,
eos_id=None,
sentinel_tokens=None,
):
"""Build training sample.
Arguments:
sample: A list of sentences in which each sentence is a list token ids.
target_seq_length: Desired sequence length.
max_seq_length: Maximum length of the sequence. All values are padded to
this length.
vocab_id_list: List of vocabulary ids. Used to pick a random id.
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
cls_id: Start of example id.
sep_id: Separator id.
mask_id: Mask token id.
pad_id: Padding token id.
masked_lm_prob: Probability to mask tokens.
        np_rng: Random number generator. Note that this rng state should be
            numpy and not python since python randint is inclusive for
            the upper bound whereas the numpy one is exclusive.
bos_id: start of decoder example id
eos_id: end of generation id
sentinel_tokens: unique value to be substituted for every replaced span
"""
assert target_seq_length <= max_seq_length
# flatten sentences into one list
tokens = [token for sentence in sample for token in sentence]
# Truncate to `target_sequence_length`.
max_num_tokens = target_seq_length
    tokens = tokens[:max_num_tokens]
# Masking.
max_predictions_per_seq = masked_lm_prob * max_num_tokens
(tokens, masked_positions, masked_labels, _, masked_spans) = create_masked_lm_predictions(
tokenizer,
tokens,
vocab_id_list,
vocab_id_to_token_dict,
masked_lm_prob,
cls_id,
sep_id,
mask_id,
max_predictions_per_seq,
np_rng,
max_ngrams=10,
geometric_dist=True,
masking_style="t5",
)
# Padding.
(
tokens_enc,
tokens_dec_in,
labels,
enc_mask,
dec_mask,
enc_dec_mask,
loss_mask,
) = pad_and_convert_to_numpy(
tokens,
masked_positions,
masked_labels,
pad_id,
max_seq_length,
max_seq_length_dec,
masked_spans,
bos_id,
eos_id,
sentinel_tokens,
)
sample = Instance(
encoder_input_ids=DistTensorData(tokens_enc),
decoder_input_ids=DistTensorData(tokens_dec_in),
encoder_attn_mask=DistTensorData(enc_mask),
decoder_attn_mask=DistTensorData(dec_mask),
encoder_decoder_attn_mask=DistTensorData(enc_dec_mask),
lm_labels=DistTensorData(labels, placement_idx=-1),
loss_mask=DistTensorData(loss_mask, placement_idx=-1),
)
return sample
def pad_and_convert_to_numpy(
tokens,
masked_positions,
masked_labels,
pad_id,
max_seq_length,
max_seq_length_dec,
masked_spans=None,
bos_id=None,
eos_id=None,
sentinel_tokens=None,
):
"""Pad sequences and convert them to numpy."""
sentinel_tokens = collections.deque(sentinel_tokens)
t5_input = []
(t5_decoder_in, t5_decoder_out) = ([bos_id], [])
(start_index, end_index) = (0, None)
for span in masked_spans:
flag = sentinel_tokens.popleft()
# Append the same tokens in decoder input and output
t5_decoder_in.append(flag)
t5_decoder_in.extend(span.label)
t5_decoder_out.append(flag)
t5_decoder_out.extend(span.label)
end_index = span.index[0]
t5_input.extend(tokens[start_index:end_index])
t5_input.append(flag)
# the next start index is the token after the last span token
start_index = span.index[-1] + 1
# Add <eos> token to the t5_decoder_out
t5_decoder_out.append(eos_id)
# Add the remaining tokens to the t5 input
t5_input.extend(tokens[start_index:])
# assert (len(t5_input) - len(masked_spans)) + \
# (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens)
# Some checks.
# Encoder-side padding mask.
num_tokens = len(t5_input)
padding_length = max_seq_length - num_tokens
assert padding_length >= 0
assert len(masked_positions) == len(masked_labels)
# Tokens..
filler = [pad_id] * padding_length
tokens_enc = np.array(t5_input + filler, dtype=np.int64)
# Decoder-side padding mask.
num_tokens_dec = len(t5_decoder_in)
padding_length_dec = max_seq_length_dec - num_tokens_dec
assert padding_length_dec >= 0
filler_dec = [pad_id] * padding_length_dec
tokens_dec_in = np.array(t5_decoder_in + filler_dec, dtype=np.int64)
# Create attention masks
enc_mask = make_attention_mask(tokens_enc, tokens_enc)
enc_dec_mask = make_attention_mask(tokens_dec_in, tokens_enc)
dec_mask = make_attention_mask(tokens_dec_in, tokens_dec_in)
dec_mask = dec_mask * make_history_mask(tokens_dec_in)
# Labels mask.
labels = t5_decoder_out + ([-1] * padding_length_dec)
labels = np.array(labels, dtype=np.int64)
# Loss mask
loss_mask = ([1] * num_tokens_dec) + ([0] * padding_length_dec)
    loss_mask = np.array(loss_mask, dtype=bool)
tokens_enc = flow.tensor(tokens_enc, dtype=flow.long)
tokens_dec_in = flow.tensor(tokens_dec_in, dtype=flow.long)
labels = flow.tensor(labels, dtype=flow.long)
enc_mask = flow.tensor(enc_mask, dtype=flow.bool)
dec_mask = flow.tensor(dec_mask, dtype=flow.bool)
enc_dec_mask = flow.tensor(enc_dec_mask, dtype=flow.bool)
loss_mask = flow.tensor(loss_mask, dtype=flow.bool)
return tokens_enc, tokens_dec_in, labels, enc_mask, dec_mask, enc_dec_mask, loss_mask
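# Illustrative sketch (editor's note, hypothetical ids): if the original sequence is
# [a, b, c, d, e] and the span [c, d] was masked with sentinel S0, the streams become
#   t5_input       -> [a, b, S0, e]     (+ padding to max_seq_length)
#   t5_decoder_in  -> [bos, S0, c, d]   (+ padding to max_seq_length_dec)
#   t5_decoder_out -> [S0, c, d, eos]   (the lm_labels, padded with -1)
# The attention masks are built from the padded id arrays, with the decoder mask
# additionally made causal via make_history_mask.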
def make_attention_mask(source_block, target_block):
"""
Returns a 2-dimensional (2-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)
# (source_length, target_length)
return mask
def make_history_mask(block):
length = block.shape[0]
arange = np.arange(length)
history_mask = (
arange[
None,
]
<= arange[:, None]
)
history_mask = history_mask.astype(np.int64)
return history_mask
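# Illustrative sketch (editor's note): for a length-3 block, make_history_mask returns
# the lower-triangular matrix
#   [[1, 0, 0],
#    [1, 1, 0],
#    [1, 1, 1]]
# which, multiplied with make_attention_mask(dec, dec), yields a causal mask that also
# zeroes out positions whose ids are < 1 (i.e. padding, assuming pad_id == 0).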
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .samplers import CyclicSampler, SingleRoundSampler
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
from oneflow.utils.data import Sampler
class CyclicSampler(Sampler):
"""
    This sampler supports cyclic sampling, and it is compatible with both
    data-parallel and non-data-parallel training.
Arguments:
dataset: dataset to be sampled.
        micro_batch_size: batch size per model instance.
            global_batch_size is micro_batch_size times data_parallel_size.
shuffle: whether to shuffle the dataset.
consumed_samples: the number of samples that have been trained at the current time,
used for resuming training (default: ``0``).
data_parallel_rank: local rank for data parallelism.
data_parallel_size: the size of data parallelism.
seed: random seed, used for reproducing experiments (default: ``0``).
"""
def __init__(
self,
dataset,
micro_batch_size,
shuffle=False,
consumed_samples=0,
data_parallel_rank=0,
data_parallel_size=1,
seed=0,
):
self.dataset = dataset
self.data_size = len(self.dataset)
self.shuffle = shuffle
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_size = micro_batch_size
self.actual_batch_size = self.micro_batch_size * self.data_parallel_size
self.data_size_per_epoch = self.data_size // self.actual_batch_size * self.micro_batch_size
self.consumed_samples = consumed_samples
self.seed = seed
def __iter__(self):
"""divide the data into data_parallel_size buckets,
and shuffle it if `shuffle` is set to `True`.
Each processor samples from its own buckets and data_loader
will load the corresponding data.
"""
epoch = self.consumed_samples // self.data_size_per_epoch
current_epoch_samples = self.consumed_samples % self.data_size_per_epoch
batch = []
while True:
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * self.data_size_per_epoch
if self.shuffle:
generator = flow.Generator()
generator.manual_seed(self.seed + epoch)
random_idx = flow.randperm(self.data_size_per_epoch, generator=generator).tolist()
indices = [start_idx + x for x in random_idx[bucket_offset:]]
else:
seq_idx = flow.arange(self.data_size_per_epoch).tolist()
indices = [start_idx + x for x in seq_idx[bucket_offset:]]
epoch += 1
if hasattr(self.dataset, "supports_prefetch") and self.dataset.supports_prefetch:
self.dataset.prefetch(indices)
for idx in indices:
batch.append(idx)
if len(batch) == self.micro_batch_size:
self.consumed_samples += self.actual_batch_size
yield batch
batch = []
current_epoch_samples = 0
def __len__(self):
return self.data_size
def set_consumed_samples(self, consumed_samples):
"""You can recover the training iteration by setting `consumed_samples`."""
self.consumed_samples = consumed_samples
def set_epoch(self, epoch):
"""Used for restoring training status."""
self.epoch = epoch
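# Illustrative usage sketch (editor's note, not part of the original module;
# a single-process configuration is assumed):
#
#   sampler = CyclicSampler(dataset, micro_batch_size=4, shuffle=True,
#                           data_parallel_rank=0, data_parallel_size=1)
#   loader = flow.utils.data.DataLoader(dataset, batch_sampler=sampler)
#
# Each rank cycles through its own bucket of indices indefinitely, yielding a list of
# `micro_batch_size` indices per step; `set_consumed_samples` lets a resumed job
# continue from where the previous run stopped.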
class SingleRoundSampler(Sampler):
"""
    This sampler supports single-round sampling, and it is compatible with both
    data-parallel and non-data-parallel training.
Arguments:
dataset: dataset to be sampled.
        micro_batch_size: batch size per model instance, global_batch_size
            is micro_batch_size times data_parallel_size.
shuffle: whether to shuffle the dataset.
data_parallel_rank: local rank for data parallelism.
data_parallel_size: the size of data parallelism.
seed: random seed, used for reproducing experiments (default: ``0``).
drop_last: whether to drop the remaining data (default: ``False``).
"""
def __init__(
self,
dataset,
micro_batch_size,
shuffle=False,
data_parallel_rank=0,
data_parallel_size=1,
seed=0,
drop_last=False,
):
self.dataset = dataset
self.data_size = len(self.dataset)
self.shuffle = shuffle
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_size = micro_batch_size
self.seed = seed
self.drop_last = drop_last
def __iter__(self):
bucket_size = self.data_size // self.data_parallel_size
remain = self.data_size % self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
if self.data_parallel_rank < remain:
bucket_size += 1
start_idx += min(self.data_parallel_rank, remain)
if self.shuffle:
generator = flow.Generator()
generator.manual_seed(self.seed)
random_idx = flow.randperm(bucket_size, generator=generator).tolist()
indices = [start_idx + x for x in random_idx]
else:
seq_idx = flow.arange(bucket_size).tolist()
indices = [start_idx + x for x in seq_idx]
if hasattr(self.dataset, "supports_prefetch") and self.dataset.supports_prefetch:
self.dataset.prefetch(indices)
batch = []
for idx in indices:
batch.append(idx)
if len(batch) == self.micro_batch_size:
yield batch
batch = []
if not self.drop_last:
if self.data_parallel_rank >= remain and remain > 0:
batch.append(0)
if len(batch) > 0:
yield batch
def __len__(self):
global_batch_size = self.micro_batch_size * self.data_parallel_size
if self.drop_last:
return self.data_size // global_batch_size
else:
return (self.data_size + global_batch_size - 1) // global_batch_size
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, List
import oneflow as flow
from libai.utils import distributed as dist
@dataclass
class DistTensorData:
tensor: flow.Tensor
sbp_list: list = field(default_factory=lambda: ["split_0", "broadcast"])
placement_idx: int = 0
# Tensor-like methods
def to_global(self, sbp=None, placement=None, device_type="cuda"):
if sbp is not None:
self.sbp = sbp
else:
sbp_list = []
for sbp in self.sbp_list:
sbp = sbp.split("_")
if len(sbp) > 1:
# split dim
assert sbp[0] == "split"
split_dim = int(sbp[1])
sbp_list.append(flow.sbp.split(split_dim))
else:
sbp_sign = sbp[0]
sbp_list.append(getattr(flow.sbp, sbp_sign))
self.sbp = dist.get_nd_sbp(sbp_list)
if placement is not None:
self.tensor = self.tensor.to_global(sbp=self.sbp, placement=placement)
else:
            # Convert the local tensor to a global tensor with the default setting
            # when the placement parameter is not provided.
            # With pipeline-parallel training enabled, the devices are grouped into
            # several device groups and the model is split into several stages,
            # each stage being placed on its corresponding device group.
            # For tensors that are consumed by the last stage, we first convert them
            # to global tensors by retaining only the copies on device group 0,
            # and then transfer the result to the last stage.
            # This ensures that all tensors used by the model are generated by the
            # first device group, in case each device group applies its own random
            # augmentations without sharing the same global seed.
main_placement = dist.get_layer_placement(0, device_type)
self.tensor = self.tensor.to_global(sbp=self.sbp, placement=main_placement)
if self.placement_idx != 0:
self.tensor = self.tensor.to_global(
placement=dist.get_layer_placement(self.placement_idx, device_type)
)
@staticmethod
def stack(distTensor_lists: List["DistTensorData"]) -> "DistTensorData":
if not isinstance(distTensor_lists[0].tensor, flow.Tensor):
raise TypeError(
"DistTensorData.tensor must be a flow.Tensor, but got {}. "
"Please check the return values of `__getitem__` in dataset.".format(
type(distTensor_lists[0].tensor)
)
)
assert len(distTensor_lists) > 0
if len(distTensor_lists) == 1:
# TODO(l1aoxingyu): add inplace unsqueeze
# distTensor_lists[0].tensor.unsqueeze_(0) # add batch dim
distTensor_lists[0].tensor = distTensor_lists[0].tensor.unsqueeze(0) # add batch dim
return distTensor_lists[0]
tensor_size = distTensor_lists[0].tensor.size()
sbp_list = distTensor_lists[0].sbp_list
placement_idx = distTensor_lists[0].placement_idx
tensors = []
for data in distTensor_lists:
assert (
data.tensor.size() == tensor_size
), f"tensor shape is not equal, {data.tensor.size()} != {tensor_size}"
assert (
data.sbp_list == sbp_list
), f"sbp_list is not equal, {data.sbp_list} != {sbp_list}!"
assert (
data.placement_idx == placement_idx
), f"placement_idx is not equal, {data.placement_idx} != {placement_idx}"
tensors.append(data.tensor)
tensors = flow.stack(tensors, dim=0)
ret = DistTensorData(tensors, sbp_list=sbp_list, placement_idx=placement_idx)
return ret
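# Illustrative sketch (editor's note, assumes a configured libai distributed context):
# DistTensorData couples a local tensor with the SBP signature and pipeline stage it
# should be placed on, e.g.
#
#   dtd = DistTensorData(flow.zeros(2, 8), sbp_list=["split_0", "broadcast"])
#   dtd.to_global()  # local tensor -> global tensor on the stage-0 placement
#
# `placement_idx=-1` defers a tensor to the last pipeline stage, which is why labels
# and loss masks in the datasets above are created with that value.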
class Instance:
"""
    This class represents an instance with metadata as attributes.
    It stores the attributes of an instance (e.g., image, tokens) as "fields".
    All other (non-field) attributes of this class are considered private:
    they must start with '_' and are not modifiable by a user.
Some basic usage:
1. Set/get/check a field:
.. code-block:: python
instance.tokens = Metadata(...)
instance.mask = Metadata(...)
print(instance.tokens)
print(instance.has("mask")) # True
    2. ``len(instance)`` returns the number of fields
"""
def __init__(self, **kwargs):
self._fields = OrderedDict()
for k, v in kwargs.items():
self.set(k, v)
def __setattr__(self, name: str, val: Any) -> None:
if name.startswith("_"):
super().__setattr__(name, val)
else:
self.set(name, val)
def __getattr__(self, name: str):
if name == "_fields" or name not in self._fields:
raise AttributeError(f"Cannot find field '{name}' in the given Instance!")
return self._fields[name]
def set(self, name: str, value: Any):
"""
Set the field named `name` to `value`.
"""
self._fields[name] = value
def has(self, name: str):
return name in self._fields
def remove(self, name: str):
del self._fields[name]
def get(self, name: str):
return self._fields[name]
def get_fields(self):
return self._fields
def __len__(self):
return len(self._fields.keys())
def __iter__(self):
        raise NotImplementedError("`Instance` object is not iterable!")
@staticmethod
def stack(instance_lists: List["Instance"]) -> "Instance":
assert all(isinstance(i, Instance) for i in instance_lists)
assert len(instance_lists) > 0
ret = Instance()
for k in instance_lists[0]._fields.keys():
values = [i.get(k) for i in instance_lists]
v0 = values[0]
if isinstance(v0, flow.Tensor):
values = flow.stack(values, dim=0)
elif isinstance(v0, list):
pass
elif hasattr(type(v0), "stack"):
values = type(v0).stack(values)
else:
raise ValueError("Unsupported type {} for stack.".format(type(v0)))
ret.set(k, values)
return ret
def __str__(self):
s = self.__class__.__name__ + "("
s += "fields=[{}]".format(", ".join((f"{k}: {v}" for k, v in self._fields.items())))
return s
__repr__ = __str__
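# Illustrative usage sketch (editor's note): the default collate path stacks a list of
# per-sample Instances field by field, e.g.
#
#   batch = Instance.stack([dataset[0], dataset[1]])
#   batch.input_ids.to_global()  # DistTensorData holding a (2, seq_len) tensor
#
# Fields whose type defines a `stack` method (such as DistTensorData) are stacked with
# it; plain flow.Tensor fields are stacked with flow.stack.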
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .default import DefaultTrainer, default_setup
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import os
import time
from collections import OrderedDict
from typing import Callable, Optional
import oneflow as flow
from omegaconf import OmegaConf
from termcolor import colored
from libai.config import LazyConfig, instantiate, try_get_key
from libai.data import Instance
from libai.engine import hooks
from libai.engine.trainer import EagerTrainer, GraphTrainer, TrainerBase
from libai.evaluation import inference_on_dataset, print_csv_format
from libai.models import build_graph, build_model
from libai.optim import build_optimizer
from libai.scheduler import build_lr_scheduler
from libai.tokenizer import build_tokenizer
from libai.utils import distributed as dist
from libai.utils.checkpoint import Checkpointer
from libai.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
from libai.utils.logger import setup_logger
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/defaults.py
# --------------------------------------------------------
def _highlight(code, filename):
try:
import pygments
except ImportError:
return code
from pygments.formatters import Terminal256Formatter
from pygments.lexers import Python3Lexer, YamlLexer
lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer()
code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai"))
return code
def _check_batch_size(cfg):
train_micro_batch_size = try_get_key(cfg, "train.train_micro_batch_size", default=None)
global_batch_size = try_get_key(cfg, "train.global_batch_size", default=None)
num_accumulation_steps = try_get_key(cfg, "train.num_accumulation_steps", default=None)
if train_micro_batch_size is not None and global_batch_size is not None:
if num_accumulation_steps is None:
if global_batch_size % (train_micro_batch_size * dist.get_data_parallel_size()) != 0:
raise ValueError(
f"global_batch_size {global_batch_size} must be divisible by "
"train_micro_batch_size * data_parallel_size "
f"({train_micro_batch_size} * {dist.get_data_parallel_size()})"
)
cfg.train.num_accumulation_steps = global_batch_size // (
train_micro_batch_size * dist.get_data_parallel_size()
)
else:
if (
global_batch_size
!= train_micro_batch_size * dist.get_data_parallel_size() * num_accumulation_steps
):
raise ValueError(
f"global_batch_size {global_batch_size} must equal to "
"train_micro_batch_size * data_parallel_size * num_accumulation_steps "
f"({train_micro_batch_size} * {dist.get_data_parallel_size()} * {num_accumulation_steps})" # noqa
)
elif train_micro_batch_size is not None and global_batch_size is None:
if num_accumulation_steps is None:
cfg.train.num_accumulation_steps = 1
cfg.train.global_batch_size = (
train_micro_batch_size
* dist.get_data_parallel_size()
* cfg.train.num_accumulation_steps
)
elif train_micro_batch_size is None and global_batch_size is not None:
if num_accumulation_steps is None:
cfg.train.num_accumulation_steps = 1
if (
global_batch_size % (dist.get_data_parallel_size() * cfg.train.num_accumulation_steps)
!= 0
):
raise ValueError(
f"global_batch_size {global_batch_size} must be divisible by "
"data_parallel_size * num_accumulation_steps "
f"({dist.get_data_parallel_size()} * {cfg.train.num_accumulation_steps})"
)
cfg.train.train_micro_batch_size = global_batch_size // (
dist.get_data_parallel_size() * cfg.train.num_accumulation_steps
)
else:
raise ValueError("train_micro_batch_size and global_batch_size must be set either")
# Set total training samples.
cfg.train.samples = cfg.train.train_iter * cfg.train.global_batch_size
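# A worked example of the batch-size relation enforced above (hypothetical numbers):
# with train_micro_batch_size = 4, data_parallel_size = 8 and num_accumulation_steps = 2,
# global_batch_size must be 4 * 8 * 2 = 64; conversely, given only global_batch_size = 64,
# train_micro_batch_size is derived as 64 // (8 * 2) = 4.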
def _compile_dependencies():
logger = logging.getLogger(__name__)
# =========================
# Compile dataset C++ code.
# =========================
# TODO: move this to ninja
if dist.get_local_rank() == 0:
start_time = time.time()
logger.info("> compiling dataset index builder ...")
from libai.data.data_utils import compile_helper
compile_helper()
logger.info(
">>> done with dataset index builder. Compilation time: {:.3f} "
"seconds".format(time.time() - start_time)
)
dist.synchronize()
if dist.get_local_rank() == 0:
logger.info(
">>> done with compiling. "
"Compilation time: {:.3f} seconds".format(time.time() - start_time)
)
def default_setup(cfg, args):
"""
Perform some basic common setups at the beginning of a job, including:
1. Set up the libai logger
2. Log basic information about environment, cmdline arguments, and config
3. Setup the distributed environment
4. Setup tokenizer if it's an NLP related task
5. Check batch_size
6. Backup the config to the output directory
7. Compile dependencies
Args:
cfg (omegaconf.dictconfig.DictConfig): the full config to set up
args (argparse.Namespace): the command line arguments to be logged
"""
output_dir = try_get_key(cfg, "train.output_dir")
if dist.is_main_process() and output_dir:
os.makedirs(output_dir, exist_ok=True)
cfg.train.resume = args.resume
rank = dist.get_rank()
logger = setup_logger(output_dir, distributed_rank=rank)
logger.info("Rank of current process: {}. World size: {}".format(rank, dist.get_world_size()))
logger.info("Command line arguments: " + str(args))
if hasattr(args, "config_file") and args.config_file != "":
logger.info(
"Contents of args.config_file={}:\n{}".format(
args.config_file,
_highlight(open(args.config_file, "r").read(), args.config_file),
)
)
dist.setup_dist_util(cfg.train.dist)
_check_batch_size(cfg)
if dist.is_main_process() and output_dir:
# Note: some of our scripts may expect the existence of
# config.yaml in output directory
path = os.path.join(output_dir, "config.yaml")
LazyConfig.save(cfg, path)
logger.info("Full config saved to {}".format(path))
flow.boxing.nccl.set_fusion_threshold_mbytes(
try_get_key(cfg, "train.nccl_fusion_threshold_mb", default=16)
)
flow.boxing.nccl.set_fusion_max_ops_num(
try_get_key(cfg, "train.nccl_fusion_max_ops", default=24)
)
_compile_dependencies()
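# A minimal usage sketch for `default_setup` from a training entry script. This is a
# hedged sketch: `default_argument_parser` and `LazyConfig.load` are assumed to be
# available in `libai.config`, as in typical LazyConfig-based tooling.
#
#   from libai.config import LazyConfig, default_argument_parser
#
#   args = default_argument_parser().parse_args()
#   cfg = LazyConfig.load(args.config_file)
#   default_setup(cfg, args)
#   trainer = DefaultTrainer(cfg)
#   trainer.resume_or_load(resume=args.resume)
#   trainer.train()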
class DefaultTrainer(TrainerBase):
"""
A trainer with default training logic. Compared to `TrainerBase`, it
also contains the following logic:
1. Create model, optimizer, scheduler, dataloader from the given config.
2. Load a checkpoint or ``cfg.train.load_weight``, if it exists.
3. Register a few common hooks defined by the config.
It is created to simplify the **standard model training workflow** and reduce code
boilerplate for users who only need that standard workflow.
This means this class makes **many assumptions** about your training logic that
may easily become invalid in new research. In fact, any assumptions beyond those made in
:class:`TrainerBase` may be too restrictive for research.
The code of this class has been annotated about restrictive assumptions it made.
When they do not work for you, you're encouraged to:
1. Overwrite methods of this class, OR:
2. Use :class:`TrainerBase`, which only does minimal SGD training and
nothing else. You can then add your own hooks if needed. OR:
3. Write your own training loop similar to ``tools/train_net.py``.
Also note that the behavior of this class, like other functions/classes in
this file, is not stable, since it is meant to represent the "common default behavior".
It is only guaranteed to work well with the standard models and training workflow in libai.
To obtain more stable behavior, write your own training logic with other public APIs.
Examples:
.. code-block:: python
trainer = DefaultTrainer(cfg)
trainer.resume_or_load() # load last checkpoint or cfg.train.load_weight
trainer.train()
Attributes:
scheduler:
checkpointer (Checkpointer):
cfg (omegaconf.dictconfig.DictConfig):
"""
def __init__(self, cfg):
"""
Args:
cfg (omegaconf.dictconfig.DictConfig):
"""
super().__init__()
self.cfg = cfg
logger = logging.getLogger("libai")
# setup_logger is not called for LiBai
if not logger.isEnabledFor(logging.INFO):
setup_logger()
# Initialize tokenizer
self.tokenizer = self.build_tokenizer(cfg)
self.start_iter = 0
if cfg.train.resume:
save_file = os.path.join(cfg.train.output_dir, "last_checkpoint")
try:
with open(save_file, "r") as f:
last_saved = f.read().strip()
assert (
last_saved != "model_final"
), "model training has finished, check your model in train.output_dir"
self.start_iter = int(last_saved.split("_")[-1]) + 1
except IOError:
# If file doesn't exist, maybe because it has just been deleted.
# We just set start_iter to 0.
self.start_iter = 0
if cfg.graph.enabled:
cfg.dataloader.consumed_samples = self.start_iter * cfg.train.global_batch_size
else:
cfg.dataloader.consumed_samples = (
self.start_iter * cfg.train.global_batch_size // cfg.train.num_accumulation_steps
)
self.train_loader = None
self.test_loader = []
train_loader, val_loader, test_loader = self.build_train_loader(cfg, self.tokenizer)
self.train_loader = train_loader
if val_loader is not None:
self.test_loader.append(val_loader)
if test_loader is not None:
self.test_loader.append(test_loader)
self.test_loader.extend(self.build_test_loader(cfg, self.tokenizer))
if cfg.train.rdma_enabled:
# set rdma
flow.env.init_rdma()
# Automatically scale the hyperparams
self.auto_scale_hyperparams(cfg, self.train_loader)
# Assume these objects must be constructed in this order.
dist.synchronize()
start_time = time.time()
logger.info("> Start building model...")
self.model = self.build_model(cfg)
dist.synchronize()
logger.info(
">>> done with building model. "
"Building time: {:.3f} seconds".format(time.time() - start_time)
)
self.optimizer = self.build_optimizer(cfg, self.model)
self.lr_scheduler = self.build_lr_scheduler(cfg, self.optimizer)
if cfg.graph.enabled:
self.graph_train = self.build_graph(
cfg, self.model, self.optimizer, self.lr_scheduler, is_train=True
)
self.graph_eval = self.build_graph(cfg, self.model, is_train=False)
self._trainer = GraphTrainer(
self.graph_train, self.train_loader, cfg.train.num_accumulation_steps
)
else:
self._trainer = EagerTrainer(
self.model, self.train_loader, self.optimizer, cfg.train.num_accumulation_steps
)
# Assume no other objects need to be checkpointed.
# We can later make it checkpoint the stateful hooks
if cfg.graph.enabled:
self.checkpointer = Checkpointer(
# Assume you want to save checkpoints together with logs/statistics
self.model,
cfg.train.output_dir,
# In static graph mode, optimizer and scheduler state_dict will
# be saved with graph.state_dict().
graph=self.graph_train,
# We print lr by `LRScheduler` hook, so we need to save/load eager lr_scheduler,
# otherwise, lr will be reset to initial state when resuming training.
lr_scheduler=self.lr_scheduler,
)
else:
self.checkpointer = Checkpointer(
# Assume you want to save checkpoints together with logs/statistics
self.model,
cfg.train.output_dir,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
)
# Load the checkpoint before dataloader construction, because the
# dataloader needs to know the consumed iterations from
# the last checkpoint.
self.resume_or_load(cfg.train.resume)
cfg.train.start_iter = self.start_iter
# global_batch_size = micro_batch_size * num_gpus * num_accumulation_steps
# When using gradient accumulation in graph mode, each run_step
# handles `global_batch_size` samples.
# When using gradient accumulation in eager mode, each run_step only handles
# `micro_batch_size * num_gpus` samples, so we need to divide by `num_accumulation_steps`
# to get the actual `batch_size` for computing `throughput` and `consumed_samples`.
self.global_batch_size = (
cfg.train.global_batch_size
if cfg.graph.enabled
else cfg.train.global_batch_size // cfg.train.num_accumulation_steps
)
self.max_iter = cfg.train.train_iter
self.register_hooks(self.build_hooks())
def resume_or_load(self, resume=True):
"""
If `resume==True` and `cfg.train.output_dir` contains the last checkpoint (defined by
a `last_checkpoint` file), resume from the file. Resuming means loading all
available states (e.g. optimizer and scheduler) and updating the iteration counter
from the checkpoint. ``cfg.train.load_weight`` will not be used.
Otherwise, this is considered an independent training run. The method will load model
weights from ``cfg.train.load_weight`` (but will not load other states) and start
from iteration 0.
Args:
resume (bool): whether to do resume or not
"""
weight_path = self.cfg.train.load_weight
assert isinstance(
weight_path, str
), f"cfg.train.load_weight:{self.cfg.train.load_weight} must be string"
if resume:
assert self.checkpointer.has_checkpoint()
# The checkpoint stores the training iteration that just finished, thus we start
# at the next iteration (or iter zero if there's no checkpoint).
assert self.start_iter == (
self.checkpointer.resume_or_load(None, resume=True).get("iter", -1) + 1
)
elif len(weight_path) != 0:
assert os.path.isdir(
weight_path
), f"cfg.train.load_weight:{self.cfg.train.load_weight} must be directory"
self.checkpointer.load(weight_path, checkpointables=[])
def build_hooks(self):
"""
Build a list of default hooks, including timing, evaluation,
checkpointing, lr scheduling, and writing events.
Returns:
list[HookBase]:
"""
ret = [
hooks.IterationTimer(),
hooks.LRScheduler(), # for printing the lr nicely in `nn.Graph` mode
hooks.PeriodicCheckpointer(self.checkpointer, self.cfg.train.checkpointer.period),
]
if self.cfg.train.evaluation.enabled:
assert self.cfg.train.evaluation.eval_iter > 0, "eval_iter must be a positive number"
def test_and_save_results():
model = self.graph_eval if self.cfg.graph.enabled else self.model
self._last_eval_results = self.test(self.cfg, self.test_loader, model)
return self._last_eval_results
ret.append(hooks.EvalHook(self.cfg.train.evaluation.eval_period, test_and_save_results))
ret.append(
hooks.BestCheckpointer(
self.cfg.train.evaluation.eval_period,
self.checkpointer,
val_metric=try_get_key(
self.cfg, "train.evaluation.eval_metric", default="Acc@1"
),
mode=try_get_key(self.cfg, "train.evaluation.eval_mode", default="max"),
)
)
if dist.is_main_process():
# run writers in the end, so that evaluation metrics are written
ret.append(hooks.PeriodicWriter(self.build_writers(), self.cfg.train.log_period))
return ret
def build_writers(self):
"""
Build a list of writers to be used. By default it contains
writers that write metrics to the screen,
a json file, and a tensorboard event file respectively.
If you'd like a different list of writers, you can overwrite it in
your trainer.
Returns:
list[EventWriter]: a list of :class:`EventWriter` objects.
It is now implemented by:
.. code-block:: python
return [
CommonMetricPrinter(self.global_batch_size, self.max_iter),
JSONWriter(os.path.join(self.cfg.train.output_dir, "metrics.json")),
TensorboardXWriter(self.cfg.train.output_dir),
]
"""
# Assume the default print/log frequency.
return [
# It may not always print what you want to see, since it prints "common" metrics only.
CommonMetricPrinter(self.global_batch_size, self.max_iter),
JSONWriter(os.path.join(self.cfg.train.output_dir, "metrics.json")),
TensorboardXWriter(self.cfg.train.output_dir),
]
def train(self):
"""
Run training.
Returns:
OrderedDict of results, if evaluation is enabled. Otherwise None.
"""
super().train(self.start_iter, self.max_iter)
def run_step(self):
self._trainer.iter = self.iter
self._trainer.run_step(self.get_batch, self.cfg.train.input_placement_device)
@classmethod
def get_batch(
cls,
data: Instance,
input_placement_device: str = "cuda",
mixup_func: Optional[Callable] = None,
):
"""
Convert batched local tensor to distributed tensor for model step running.
If you want to do something with the batched data before it is fed to the model
(e.g. mixup), you can override this function.
"""
if isinstance(data, flow.utils.data._utils.worker.ExceptionWrapper):
data.reraise()
if mixup_func is not None:
images, labels = mixup_func(
data.get("images").tensor.cuda(),
data.get("labels").tensor.cuda(),
)
data.get("images").tensor = images
data.get("labels").tensor = labels
ret_dict = {}
for key, value in data.get_fields().items():
value.to_global(device_type=input_placement_device)
ret_dict[key] = value.tensor
return ret_dict
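# A hedged sketch of overriding `get_batch` for custom batch preprocessing, as the
# docstring above suggests. `my_transform` and the "images" field name are
# hypothetical examples, not part of this file.
#
#   class MyTrainer(DefaultTrainer):
#       @classmethod
#       def get_batch(cls, data, input_placement_device="cuda", mixup_func=None):
#           data.get("images").tensor = my_transform(data.get("images").tensor)
#           return super().get_batch(data, input_placement_device, mixup_func)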
@classmethod
def build_tokenizer(cls, cfg):
"""
Returns:
libai.tokenizer.PreTrainedTokenizer:
It now calls :func:`libai.tokenizer.build_tokenizer`.
"""
tokenizer = None
if try_get_key(cfg, "tokenization") is not None:
tokenizer = build_tokenizer(cfg.tokenization)
# FIXME(lxy): In case model is not defined with cfg, the `vocab_size` can be
# accessed by `model.vocab_size`.
if try_get_key(cfg, "model.cfg.vocab_size", default=None) is not None:
# In case the model does not need vocab_size as argument
multiple = (
cfg.tokenization.make_vocab_size_divisible_by
* cfg.train.dist.tensor_parallel_size
)
cfg.model.cfg.vocab_size = tokenizer.padded_vocab_size(multiple)
return tokenizer
@classmethod
def build_model(cls, cfg):
"""
Returns:
flow.nn.Module:
It now calls :func:`libai.models.build_model`.
Overwrite it if you'd like a different model.
"""
assert try_get_key(cfg, "model") is not None, "cfg must contain `model` namespace"
# Set the model fp16 option, if provided, because the embedding layer manually
# inserts `white_identity` for amp training.
if try_get_key(cfg.model, "cfg.amp_enabled") is not None:
cfg.model.cfg.amp_enabled = cfg.train.amp.enabled and cfg.graph.enabled
# In case some models are defined without the cfg keyword.
elif try_get_key(cfg.model, "amp_enabled") is not None:
cfg.model.amp_enabled = cfg.train.amp.enabled and cfg.graph.enabled
model = build_model(cfg.model)
logger = logging.getLogger(__name__)
logger.info("Model:\n{}".format(model))
model._apply(dist.convert_to_distributed_default_setting)
return model
@classmethod
def build_graph(cls, cfg, model, optimizer=None, lr_scheduler=None, is_train=True):
assert try_get_key(cfg, "graph") is not None, "cfg must contain `graph` namespace"
graph = build_graph(cfg, model, optimizer, lr_scheduler, is_train)
debug_graph = try_get_key(cfg, "graph.debug", default=-1)
if debug_graph >= 0:
logger = logging.getLogger(__name__)
logger.info("Graph debug mode on, automatically output debug info.")
graph.debug(cfg.graph.debug)
return graph
@classmethod
def build_optimizer(cls, cfg, model):
"""
Returns:
flow.optim.Optimizer:
It now calls :func:`libai.optim.build_optimizer`.
Overwrite it if you'd like a different optimizer.
"""
assert try_get_key(cfg, "optim") is not None, "cfg must contain `optim` namespace"
return build_optimizer(cfg.optim, model)
@classmethod
def build_lr_scheduler(cls, cfg, optimizer):
"""
It now calls :func:`libai.scheduler.build_lr_scheduler`.
Overwrite it if you'd like a different scheduler.
"""
assert (
try_get_key(cfg, "train.scheduler") is not None
), "cfg.train must contain `scheduler` namespace"
return build_lr_scheduler(cfg.train.scheduler, optimizer)
@classmethod
def build_train_loader(cls, cfg, tokenizer=None):
"""
Returns:
iterable
It now calls :func:`libai.data.build_train_valid_test_loader`.
Overwrite it if you'd like a different data loader.
"""
assert (
try_get_key(cfg, "dataloader.train") is not None
), "cfg must contain `dataloader.train` namespace"
logger = logging.getLogger(__name__)
logger.info("Prepare training, validating, testing set")
if cfg.graph.enabled:
# In static graph mode, data will be sliced in nn.Graph automatically,
# so the dataloader gets the micro-batch-size and the data is concatenated
# in graph_trainer.run_step to form the mini-batch-size.
cfg.dataloader.train.train_batch_size = cfg.train.train_micro_batch_size
else:
# In eager mode, gradient accumulation behaves like PyTorch, so the dataloader
# also gets the micro-batch-size.
cfg.dataloader.train.train_batch_size = cfg.train.train_micro_batch_size
cfg.dataloader.train.test_batch_size = cfg.train.test_micro_batch_size
cfg.dataloader.train.seed = cfg.train.seed
# used by nlp dataloader
if hasattr(cfg.dataloader.train, "train_val_test_num_samples"):
eval_iter = (
(cfg.train.train_iter // cfg.train.evaluation.eval_period + 1)
* cfg.train.evaluation.eval_iter
if cfg.train.evaluation.enabled
# samples for test_dataset must be larger than 0 even if there is no evaluation
else 1
)
test_iter = cfg.train.evaluation.eval_iter if cfg.train.evaluation.enabled else 1
cfg.dataloader.train.train_val_test_num_samples = [
int(cfg.train.samples),
int(eval_iter * cfg.train.test_micro_batch_size * dist.get_data_parallel_size()),
int(test_iter * cfg.train.test_micro_batch_size * dist.get_data_parallel_size()),
]
if OmegaConf.is_list(cfg.dataloader.train.dataset):
for dataset in cfg.dataloader.train.dataset:
if hasattr(dataset, "seed"):
dataset.seed = cfg.train.seed
else:
dataset = cfg.dataloader.train.dataset
if hasattr(dataset, "seed"):
dataset.seed = cfg.train.seed
# Set tokenizer for each dataset
if tokenizer:
if OmegaConf.is_list(cfg.dataloader.train.dataset):
for dataset in cfg.dataloader.train.dataset:
dataset.tokenizer = tokenizer
else:
cfg.dataloader.train.dataset.tokenizer = tokenizer
train_loader, valid_loader, test_loader = instantiate(
cfg.dataloader.train, _recursive_=False
)
return train_loader, valid_loader, test_loader
@classmethod
def build_test_loader(cls, cfg, tokenizer=None):
"""
Returns:
iterable
It now calls :func:`libai.data.build_image_test_loader` for CV tasks
or :func:`libai.data.build_nlp_test_loader` for NLP tasks.
Overwrite it if you'd like a different data loader.
"""
# If there is no test_loader, just return []
if not try_get_key(cfg, "dataloader.test", default=False):
return []
logger = logging.getLogger(__name__)
logger.info("Prepare testing set")
assert OmegaConf.is_list(
cfg.dataloader.test
), f"dataloader.test must be list but got type of {type(cfg.dataloader.test)}"
for i in range(len(cfg.dataloader.test)):
cfg.dataloader.test[i].test_batch_size = cfg.train.test_micro_batch_size
cfg.dataloader.test[i].seed = cfg.train.seed # set seed
if tokenizer:
cfg.dataloader.test[i].dataset.tokenizer = tokenizer
# list[dataloader1, dataloader2, ...]
test_loader = instantiate(cfg.dataloader.test, _recursive_=False)
return test_loader
@classmethod
def auto_scale_hyperparams(cls, cfg, data_loader):
logger = logging.getLogger(__name__)
log_info = ""
# Get or set default iteration cfg
train_iter = try_get_key(cfg, "train.train_iter", default=0)
train_epoch = try_get_key(cfg, "train.train_epoch", default=0)
warmup_ratio = try_get_key(cfg, "train.warmup_ratio", default=0)
assert (
warmup_ratio < 1 and warmup_ratio >= 0
), "warmup_ratio must be in [0, 1) that presents the ratio of warmup iter to the train iter"
# Automatically scale the number of iterations depending on the settings.
# The total number of iterations in one epoch is `len(dataset) / global_batch_size`.
cfg.train.train_iter = max(
math.ceil(len(data_loader.dataset) * train_epoch / cfg.train.global_batch_size),
train_iter,
)
cfg.train.warmup_iter = math.ceil(cfg.train.train_iter * cfg.train.warmup_ratio)
if not cfg.graph.enabled:
# In eager mode, the dataloader only gets the micro-batch-size each iteration,
# which is mini-batch-size // num_accumulation_steps, so scale `train_iter`
# and `warmup_iter` to be consistent with static graph mode.
cfg.train.train_iter *= cfg.train.num_accumulation_steps
cfg.train.warmup_iter *= cfg.train.num_accumulation_steps
log_info += "Auto-scaling the config to train.train_iter={}, train.warmup_iter={}".format(
cfg.train.train_iter, cfg.train.warmup_iter
)
# Automatically scale the milestones
if try_get_key(cfg, "train.scheduler.milestones"):
if any(
milestone < 0 or milestone >= 1
for milestone in cfg.train.scheduler.milestones
):
raise ValueError(
"milestones should be a list of increasing ratio in [0, 1), but got {}".format(
cfg.train.scheduler.milestones
)
)
cfg.train.scheduler.milestones = [
int(milestone * cfg.train.train_iter)
for milestone in cfg.train.scheduler.milestones
]
log_info += f", scheduler milestones={cfg.train.scheduler.milestones}"
logger.info(log_info)
# Global scheduler cfg
cfg.train.scheduler.warmup_iter = cfg.train.warmup_iter
cfg.train.scheduler.max_iter = cfg.train.train_iter
# train iter per epoch
iter_per_epoch = len(data_loader.dataset) // cfg.train.global_batch_size
# rescale eval period
if try_get_key(cfg, "train.evaluation.eval_after_n_epoch"):
cfg.train.evaluation.eval_period = (
iter_per_epoch * cfg.train.evaluation.eval_after_n_epoch
)
logger.info(
f"Auto-scaling the config "
f"train.evaluation.eval_after_n_epoch={cfg.train.evaluation.eval_after_n_epoch} "
f"to train.evaluation.eval_period={cfg.train.evaluation.eval_period}"
)
# rescale save model period
if try_get_key(cfg, "train.checkpointer.save_model_after_n_epoch"):
cfg.train.checkpointer.period = (
iter_per_epoch * cfg.train.checkpointer.save_model_after_n_epoch
)
logger.info(
f"Auto-scaling the config "
f"train.checkpointer.save_model_after_n_epoch="
f"{cfg.train.checkpointer.save_model_after_n_epoch} "
f"to train.checkpointer.period={cfg.train.checkpointer.period}"
)
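# A worked example of the scaling above (hypothetical numbers): with a dataset of
# 10,000 samples, train_epoch = 3 and global_batch_size = 64, train_iter becomes
# ceil(30000 / 64) = 469; with warmup_ratio = 0.05, warmup_iter = ceil(469 * 0.05) = 24.
# In eager mode with num_accumulation_steps = 2, both values are further multiplied by 2.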
@classmethod
def build_evaluator(cls, cfg):
evaluator = instantiate(cfg.train.evaluation.evaluator)
return evaluator
@classmethod
def test(cls, cfg, test_loaders, model, evaluator=None):
"""
Evaluate the given model. The given model is expected to already contain
weights to evaluate.
Args:
cfg (omegaconf.dictconfig.DictConfig):
test_loaders: list [dataloader1, dataloader2, ...]
model (nn.Graph or nn.Module):
evaluator (DatasetEvaluator or None): if None, :meth:`build_evaluator`
will be called to build one.
Returns:
dict: a dict of result metrics
"""
logger = logging.getLogger(__name__)
# TODO: support multi evaluator
# if isinstance(evaluators, DatasetEvaluator):
# evaluators = [evaluators]
test_batch_size = cfg.train.test_micro_batch_size * dist.get_data_parallel_size()
evaluator = cls.build_evaluator(cfg) if not evaluator else evaluator
results = OrderedDict()
for idx, data_loader in enumerate(test_loaders):
# When an evaluator is passed in as an argument, we implicitly
# assume that it can be created before the data_loader.
dataset_name = type(data_loader.dataset).__name__
# TODO: support multi evaluator
# if evaluators is not None:
# evaluator = evaluators[idx]
# else:
# try:
# evaluator = cls.build_evaluator(cfg)
# except NotImplementedError:
# logger.warn(
# "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
# "or implement its `build_evaluator` method."
# )
# results[dataset_name] = {}
# continue
results_i = inference_on_dataset(
model,
data_loader,
test_batch_size,
cfg.train.evaluation.eval_iter,
cls.get_batch,
cfg.train.input_placement_device,
evaluator,
)
results[dataset_name] = results_i
if dist.is_main_process():
assert isinstance(
results_i, dict
), "Evaluator must return a dict on the main process. Got {} instead.".format(
results_i
)
logger.info(
"Evaluation results for {} in csv format:".format(
colored(dataset_name, "green")
)
)
print_csv_format(results_i)
if len(results) == 1:
results = list(results.values())[0]
return results
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import logging
import math
import operator
import time
from collections import Counter
import oneflow as flow
from libai.evaluation import flatten_results_dict
from libai.utils import distributed as dist
from libai.utils.checkpoint import Checkpointer
from libai.utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
from libai.utils.events import EventWriter
from libai.utils.timer import Timer
from .trainer import HookBase
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/hooks.py
# --------------------------------------------------------
"""
Implement some common hooks.
"""
logger = logging.getLogger(__name__)
class CallbackHook(HookBase):
"""
Create a hook using callback functions provided by the user.
"""
def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
"""
Each argument is a function that takes one argument: the trainer.
"""
self._before_train = before_train
self._before_step = before_step
self._after_step = after_step
self._after_train = after_train
def before_train(self):
if self._before_train:
self._before_train(self.trainer)
def after_train(self):
if self._after_train:
self._after_train(self.trainer)
# The functions may be closures that hold reference to the trainer
# Therefore, delete them to avoid circular reference.
del self._before_train, self._after_train
del self._before_step, self._after_step
def before_step(self):
if self._before_step:
self._before_step(self.trainer)
def after_step(self):
if self._after_step:
self._after_step(self.trainer)
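# A hedged usage sketch for CallbackHook; the callback below is a hypothetical
# example, not part of this file.
#
#   def log_iteration(trainer):
#       logger.info("finished iteration %d", trainer.iter)
#
#   trainer.register_hooks([CallbackHook(after_step=log_iteration)])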
class IterationTimer(HookBase):
"""
Track the time spent for each iteration (each run_step call in the trainer).
Print a summary in the end of training.
This hook uses the time between the call to its :meth:`before_step`
and :meth:`after_step` methods.
Under the convention that :meth:`before_step` of all hooks should only
take negligible amount of time, the :class:`IterationTimer` hook should be
placed at the beginning of the list of hooks to obtain accurate timing.
"""
def __init__(self, warmup_iter=3):
"""
Args:
warmup_iter (int): the number of iterations at the beginning to exclude
from timing.
"""
self._warmup_iter = warmup_iter
self._step_timer = Timer()
def before_train(self):
self._start_time = time.perf_counter()
self._total_timer = Timer()
self._total_timer.pause()
def after_train(self):
total_time = time.perf_counter() - self._start_time
total_time_minus_hooks = self._total_timer.seconds()
hook_time = total_time - total_time_minus_hooks
num_iter = self.trainer.iter + 1 - self.trainer.start_iter - self._warmup_iter
if num_iter > 0 and total_time_minus_hooks > 0:
# Speed is meaningful only after warmup
# NOTE this format is parsed by grep in some scripts
logger.info(
"Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
num_iter,
str(datetime.timedelta(seconds=int(total_time_minus_hooks))),
total_time_minus_hooks / num_iter,
)
)
logger.info(
"Total training time: {} ({} on hooks)".format(
str(datetime.timedelta(seconds=int(total_time))),
str(datetime.timedelta(seconds=int(hook_time))),
)
)
def before_step(self):
self._step_timer.reset()
self._total_timer.resume()
def after_step(self):
# +1 because we're in after_step
iter_done = self.trainer.iter - self.trainer.start_iter + 1
if iter_done >= self._warmup_iter:
sec = self._step_timer.seconds()
self.trainer.storage.put_scalars(time=sec)
else:
self._start_time = time.perf_counter()
self._total_timer.reset()
self._total_timer.pause()
class PeriodicWriter(HookBase):
"""
Write events to EventStorage periodically.
It is executed every ``period`` iterations and after the last iteration.
"""
def __init__(self, writers, period=20):
"""
Args:
writers (list[EventWriter]): a list of EventWriter objects
period (int):
"""
self._writers = writers
for w in writers:
assert isinstance(w, EventWriter), w
self._period = period
def after_step(self):
if (self.trainer.iter + 1) % self._period == 0 or (
self.trainer.iter == self.trainer.max_iter - 1
):
for writer in self._writers:
writer.write()
def after_train(self):
for writer in self._writers:
writer.close()
class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
"""
Same as :class:`libai.utils.checkpoint.PeriodicCheckpointer`, but as a hook.
Note that when used as a hook,
it is unable to save additional data other than what's defined
by the given `checkpointer`.
It is executed every ``period`` iterations and after the last iteration.
"""
def before_train(self):
self.max_iter = self.trainer.max_iter
def after_step(self):
self.step(self.trainer.iter)
class BestCheckpointer(HookBase):
"""
Checkpoints the best weights based on a given metric.
This hook should be used in conjunction with, and executed after, the hook
that produces the metric, e.g. `EvalHook`.
"""
def __init__(
self,
eval_period: int,
checkpointer: Checkpointer,
val_metric: str,
mode: str = "max",
file_prefix: str = "model_best",
) -> None:
"""
Args:
eval_period (int): the period `EvalHook` is set to run.
checkpointer: the checkpointer object used to save checkpoints.
val_metric (str): validation metric to track for best checkpoint, e.g. "acc@1"
mode (str): one of {'max', 'min'}. controls whether the chosen val metric should be
maximized or minimized, e.g. for "acc@1" it should be "max"
file_prefix (str): the prefix of checkpoint's filename, defaults to "model_best"
"""
self._period = eval_period
self._val_metric = val_metric
assert mode in [
"max",
"min",
], f'Mode "{mode}" to `BestCheckpointer` is unknown. It should be one of {"max", "min"}.'
if mode == "max":
self._compare = operator.gt
else:
self._compare = operator.lt
self._checkpointer = checkpointer
self._file_prefix = file_prefix
self.best_metric = None
self.best_iter = None
def _update_best(self, val, iteration):
if math.isnan(val) or math.isinf(val):
return False
self.best_metric = val
self.best_iter = iteration
return True
def _best_checking(self):
metric_tuple = self.trainer.storage.latest().get(self._val_metric)
flag = flow.zeros(1)
if dist.is_main_process():
if metric_tuple is None:
logger.warning(
f"Given val metric {self._val_metric} does not seem to be computed/stored. "
"Will not be checkpointed based on that."
)
else:
latest_metric, metric_iter = metric_tuple
if self.best_metric is None:
if self._update_best(latest_metric, metric_iter):
flag = flag + 1
logger.info(
f"Saved first model at {self.best_metric:0.5f} @ {self.best_iter} steps"
)
elif self._compare(latest_metric, self.best_metric):
flag = flag + 1
logger.info(
f"Saved best model as latest eval score for {self._val_metric} is "
f"{latest_metric:0.5f}, better than last best score "
f"{self.best_metric:0.5f} @ iteration {self.best_iter}."
)
self._update_best(latest_metric, metric_iter)
else:
logger.info(
f"Not saving as latest eval score for "
f"{self._val_metric} is {latest_metric:0.5f}, "
f"not better than best score {self.best_metric:0.5f} "
f"@ iteration {self.best_iter}."
)
dist.synchronize()
flag = flag.to_global(
sbp=flow.sbp.broadcast, placement=flow.env.all_device_placement("cpu")
)
if flag.to_local().item() == 1:
self._checkpointer.save(f"{self._file_prefix}")
def after_step(self):
# same conditions as `EvalHook`
next_iter = self.trainer.iter + 1
if (
self._period > 0
and next_iter % self._period == 0
and next_iter != self.trainer.max_iter
):
self._best_checking()
def after_train(self):
# same conditions as `EvalHook`
if self.trainer.iter + 1 >= self.trainer.max_iter:
self._best_checking()
class EvalHook(HookBase):
"""
Run an evaluation function periodically, and at the end of training.
It is executed every ``eval_period`` iterations and after the last iteration.
"""
def __init__(self, eval_period, eval_function):
"""
Args:
eval_period (int): the period to run `eval_function`.
eval_function (callable): a function which takes no arguments, and
returns a nested dict of evaluation metrics.
Note:
This hook must be enabled in all or none workers.
If you would like only certain workers to perform evaluation,
give other workers a no-op function (`eval_function=lambda: None`).
"""
self._period = eval_period
self._func = eval_function
def _do_eval(self):
results = self._func()
if results:
assert isinstance(
results, dict
), "Eval function must return a dict. Got {} instead.".format(results)
flattened_results = flatten_results_dict(results)
# flatten nested result dicts into "key/subkey" scalar entries
for k, v in flattened_results.items():
try:
v = float(v)
except Exception:
raise ValueError(
"[EvalHook] eval_function should return a nested dict of float. "
"Got '{}: {}' instead.".format(k, v)
)
self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
# Evaluation may take different time among workers.
# A barrier makes them start the next iteration together.
dist.synchronize()
def after_step(self):
next_iter = self.trainer.iter + 1
if self._period > 0 and next_iter % self._period == 0:
# do the last eval in after_train
if next_iter != self.trainer.max_iter:
self._do_eval()
def after_train(self):
# This condition is to prevent the eval from running after a failed training
if self.trainer.iter + 1 >= self.trainer.max_iter:
self._do_eval()
# func is likely a closure that holds reference to the trainer
# therefore we clean it to avoid circular reference in the end
del self._func
class LRScheduler(HookBase):
"""
A hook which executes a oneflow builtin LR scheduler and summarizes the LR.
It is executed after every iteration.
"""
def __init__(self, optimizer=None, scheduler=None):
"""
Args:
optimizer (flow.optim.Optimizer):
scheduler (flow.optim.LRScheduler):
if a :class:`ParamScheduler` object, it defines the multiplier over the base LR
in the optimizer.
If any argument is not given, will try to obtain it from the trainer.
"""
self._optimizer = optimizer
self._scheduler = scheduler
def before_train(self):
self._optimizer = self._optimizer or self.trainer.optimizer
self._best_param_group_id = LRScheduler.get_best_param_group_id(self._optimizer)
@staticmethod
def get_best_param_group_id(optimizer):
# NOTE: some heuristics on what LR to summarize
# summarize the param group with most parameters
largest_group = max(len(g["params"]) for g in optimizer.state_dict()["param_groups"])
if largest_group == 1:
# If all groups have one parameter,
# then find the most common initial LR, and use it for summary
lr_count = Counter(
[g["_options"]["lr"] for g in optimizer.state_dict()["param_groups"]]
)
lr = lr_count.most_common()[0][0]
for i, g in enumerate(optimizer.state_dict()["param_groups"]):
if g["_options"]["lr"] == lr:
return i
else:
for i, g in enumerate(optimizer.state_dict()["param_groups"]):
if len(g["params"]) == largest_group:
return i
def after_step(self):
lr = self.scheduler.get_last_lr()[self._best_param_group_id]
self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
self.scheduler.step()
@property
def scheduler(self):
return self._scheduler or self.trainer.lr_scheduler
def state_dict(self):
if isinstance(self.scheduler, flow.optim.lr_scheduler._LRScheduler):
return self.scheduler.state_dict()
return {}
def load_state_dict(self, state_dict):
if isinstance(self.scheduler, flow.optim.lr_scheduler._LRScheduler):
logger.info("Loading scheduler from state_dict ...")
self.scheduler.load_state_dict(state_dict)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
import weakref
from typing import Callable, List, Mapping
import oneflow as flow
from libai.utils import distributed as dist
from libai.utils.events import EventStorage, get_event_storage
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/train_loop.py
# --------------------------------------------------------
class HookBase:
"""
Base class for hooks that can be registered with :class:`TrainerBase`.
Each hook can implement 4 methods. The way they are called is demonstrated
in the following snippet:
::
hook.before_train()
for iter in range(start_iter, max_iter):
hook.before_step()
trainer.run_step()
hook.after_step()
iter += 1
hook.after_train()
Notes:
1. In the hook method, users can access ``self.trainer`` to access more
properties about the context (e.g., model, current iteration, or config
if using :class:`DefaultTrainer`).
2. A hook that does something in :meth:`before_step` can often be
implemented equivalently in :meth:`after_step`.
If the hook takes non-trivial time, it is strongly recommended to
implement the hook in :meth:`after_step` instead of :meth:`before_step`.
The convention is that :meth:`before_step` should only take negligible time.
Following this convention will allow hooks that do care about the difference
between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
function properly.
"""
trainer: "TrainerBase" = None
"""
A weak reference to the trainer object. Set by the trainer when the hook is registered.
"""
def before_train(self):
"""
Called before the first iteration.
"""
def after_train(self):
"""
Called after the last iteration.
"""
def before_step(self):
"""
Called before each iteration.
"""
def after_step(self):
"""
Called after each iteration.
"""
class TrainerBase:
"""
Base class for iterative trainer with hooks.
The only assumption we made here is: the training runs in a loop.
A subclass can implement what the loop is.
We made no assumptions about the existence of dataloader, optimizer, model, etc.
Attributes:
iter(int): The current iteration.
start_iter(int): The iteration to start with.
By convention the minimum possible value is 0.
max_iter(int): The iteration to end training.
storage(EventStorage): An EventStorage that's opened during the course of training.
"""
def __init__(self):
self._hooks: List[HookBase] = []
self.iter: int = 0
self.start_iter: int = 0
self.max_iter: int
self.storage: EventStorage
def register_hooks(self, hooks):
"""
Register hooks to the trainer. The hooks are executed in the order
they are registered.
Args:
hooks (list[Optional[HookBase]]): list of hooks
"""
hooks = [h for h in hooks if h is not None]
for h in hooks:
assert isinstance(h, HookBase)
# To avoid circular reference, hooks and trainer cannot own each other.
# This normally does not matter, but will cause memory leak if the
# involved objects contain __del__:
# See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
h.trainer = weakref.proxy(self)
self._hooks.extend(hooks)
def train(self, start_iter: int, max_iter: int):
"""
Args:
start_iter, max_iter (int): See docs above
"""
logger = logging.getLogger(__name__)
logger.info("Starting training from iteration {}".format(start_iter))
self.iter = self.start_iter = start_iter
self.max_iter = max_iter
with EventStorage(self.start_iter) as self.storage:
try:
self.before_train()
for self.iter in range(start_iter, max_iter):
self.before_step()
self.run_step()
self.after_step()
# self.iter == max_iter can be used by `after_train` to
# tell whether the training successfully finished or failed
# due to exceptions.
self.iter += 1
except Exception:
logger.exception("Exception during training:")
raise
finally:
self.after_train()
def before_train(self):
for h in self._hooks:
h.before_train()
def after_train(self):
for h in self._hooks:
h.after_train()
def before_step(self):
self.storage.iter = self.iter
for h in self._hooks:
h.before_step()
def after_step(self):
self.storage.samples = (self.iter + 1) * self.cfg.train.global_batch_size
for h in self._hooks:
h.after_step()
def run_step(self):
raise NotImplementedError
@staticmethod
def write_metrics(
loss_dict: Mapping[str, flow.Tensor],
data_time: float,
prefix: str = "",
) -> None:
"""
Args:
loss_dict (dict): dict of scalar losses
data_time (float): time taken by the dataloader iteration
prefix (str): prefix for logging keys
"""
# Move metric values to rank 0, since logger.info only works on rank 0.
metrics_dict = {
k: dist.tensor_to_rank0(v, device="cpu", to_local=True) for k, v in loss_dict.items()
}
metrics_dict["data_time"] = data_time
# TODO: Gather metrics among all workers for logging
# all_metrics_dict = dist.gather(metrics_dict)
all_metrics_dict = metrics_dict
if dist.is_main_process():
storage = get_event_storage()
# data_time among workers can have high variance. The actual latency
# caused by data_time is the maximum among workers.
# data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
data_time = all_metrics_dict.pop("data_time")
storage.put_scalar("data_time", data_time)
# average the rest metrics
# metrics_dict = {
# k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
# }
metrics_dict = all_metrics_dict
total_losses_reduced = sum(v for k, v in metrics_dict.items() if "loss" in k)
storage.put_scalar("{}total_loss".format(prefix), total_losses_reduced)
if len(metrics_dict) > 1:
storage.put_scalars(**metrics_dict)
class EagerTrainer(TrainerBase):
"""
A simple eager trainer for the most common type of task:
single-cost single-optimizer single-data-source iterative optimization,
optionally using data-parallelism.
It assumes that in every step, you:
1. Compute the loss with a data from the data_loader.
2. Compute the gradients with the above loss.
3. Update the model with the optimizer.
All other tasks during training (checkpointing, logging, evaluation, LR schedule)
are maintained by hooks, which can be registered by :meth:`TrainerBase.register_hooks`.
If you want to do anything fancier than this,
either subclass TrainerBase and implement your own `run_step`,
or write your own training loop.
"""
def __init__(self, model, data_loader, optimizer, grad_acc_steps=1):
"""
Args:
model: a flow.nn.Module. Takes data from the data_loader and returns a
dict of losses.
data_loader: an iterable. Contains data to be used to call the model.
optimizer: a flow optimizer.
grad_acc_steps: the number of gradient accumulation steps.
"""
super().__init__()
# We set the model to training mode in the trainer.
# However it's valid to train a model that's in eval mode.
# If you want your model (or a submodule of it) to behave
# like evaluation during training, you can overwrite its train() method.
model.train()
self.model = model
self.data_loader = data_loader
self._data_loader_iter = iter(data_loader)
self.optimizer = optimizer
self.grad_acc_steps = grad_acc_steps
def run_step(self, get_batch: Callable, input_placement_device: str = "cuda"):
"""
Implement the standard training logic described above.
"""
assert self.model.training, "[EagerTrainer] model was changed to eval mode!"
start = time.perf_counter()
# If you want to do something with the data, you can wrap the dataloader.
data = next(self._data_loader_iter)
data = get_batch(
data, input_placement_device, getattr(self.data_loader, "mixup_func", None)
)
data_time = time.perf_counter() - start
loss_dict = self.model(**data)
losses = sum(v for k, v in loss_dict.items() if "loss" in k) / self.grad_acc_steps
losses.backward()
self.write_metrics(loss_dict, data_time)
if (self.iter + 1) % self.grad_acc_steps == 0:
self.optimizer.clip_grad()
self.optimizer.step()
self.optimizer.zero_grad()
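# Note on the accumulation logic above: with grad_acc_steps = 2, each run_step
# back-propagates loss / 2, and the optimizer only steps (and zeroes gradients) on
# every second iteration, so two micro-batches contribute to one parameter update.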
class GraphTrainer(TrainerBase):
"""
A simple graph trainer for training and evaluating models in a static graph mode.
"""
def __init__(self, graph, data_loader, grad_acc_steps=1):
super().__init__()
graph.model.train()
self.data_loader = data_loader
self._data_loader_iter = iter(data_loader)
self.graph = graph
self.grad_acc_steps = grad_acc_steps
self._temp_data = None
self._temp_count = 0
def run_step(self, get_batch: Callable, input_placement_device: str = "cuda"):
"""
Implement the standard training logic described above.
"""
assert self.graph.model.training, "[GraphTrainer] model was changed to eval mode!"
start = time.perf_counter()
while self._temp_count != self.grad_acc_steps:
# If you want to do something with the data, you can wrap the dataloader.
data = next(self._data_loader_iter)
self._temp_count += 1
if self._temp_data is None:
self._temp_data = data
else:
# In static graph mode, data will be sliced in nn.Graph automatically;
# to form the mini-batch-size, we concatenate the local tensors first.
for key, value in data.get_fields().items():
temp_value = self._temp_data.get(key)
self._temp_data.get(key).tensor = flow.cat(
(temp_value.tensor, value.tensor), dim=0
)
data = self._temp_data
self._temp_count = 0
self._temp_data = None
data = get_batch(
data, input_placement_device, getattr(self.data_loader, "mixup_func", None)
)
data_time = time.perf_counter() - start
# If you want to do something with the losses, you can wrap the model.
loss_dict = self.graph(**data)
# When gradient accumulation is set up, the graph returns an unreduced
# n-d tensor whose first dimension equals the number of accumulation steps.
for key, value in loss_dict.items():
if "loss" in key:
loss_dict[key] = value.mean()
else:
# NOTE: only scalar tensors are currently supported
loss_dict[key] = value.sum()
self.write_metrics(loss_dict, data_time)
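# Note on the accumulation loop above: with grad_acc_steps = 4 and a dataloader
# micro-batch of 2 (hypothetical numbers), four micro-batches are concatenated along
# dim 0 into a mini-batch of 8 before being handed to nn.Graph, which slices it back
# into micro-batches internally for gradient accumulation.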