Unverified Commit ca6e75bc authored by BoxiangW's avatar BoxiangW Committed by GitHub
Browse files

[tutorial] edited hands-on practices (#1899)

* Add handson to ColossalAI.

* Change names of handsons and edit sequence parallel example.

* Edit wrong folder name

* resolve conflict

* delete readme
parent d9bf83e0
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataloaders."""
import torch
import random
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0):
"""Build dataloader given an input dataset."""
if dataset is None:
return None
# Megatron sampler
if dataloader_type == 'single':
batch_sampler = MegatronPretrainingSampler(total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=micro_batch_size,
data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
data_parallel_size=gpc.get_world_size(ParallelMode.DATA))
elif dataloader_type == 'cyclic':
batch_sampler = MegatronPretrainingRandomSampler(total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=micro_batch_size,
data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
data_parallel_size=gpc.get_world_size(ParallelMode.DATA))
else:
raise Exception('{} dataloader type is not supported.'.format(dataloader_type))
# Torch dataloader.
return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True)
class MegatronPretrainingSampler:
def __init__(self,
total_samples,
consumed_samples,
micro_batch_size,
data_parallel_rank,
data_parallel_size,
drop_last=True):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.drop_last = drop_last
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.consumed_samples < self.total_samples, \
'no samples left to consume: {}, {}'.format(self.consumed_samples,
self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
return self.total_samples
def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.micro_batch_size
end_idx = start_idx + self.micro_batch_size
return start_idx, end_idx
def __iter__(self):
batch = []
# Last batch will be dropped if drop_last is not set False
for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx)
if len(batch) == self.micro_batch_times_data_parallel_size:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
batch = []
# Check the last partial batch and see drop_last is set
if len(batch) > 0 and not self.drop_last:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
class MegatronPretrainingRandomSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.last_batch_size = \
self.total_samples % self.micro_batch_times_data_parallel_size
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
return self.total_samples
def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
# data sharding and random sampling
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
# Last batch if not complete will be dropped.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.micro_batch_size:
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []
This diff is collapsed.
import itertools
import random
import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron import get_args
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset_utils import get_block_samples_mapping
def make_attention_mask(source_block, target_block):
"""
Returns a 2-dimensional (2-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)
# (source_length, target_length)
return mask
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
"""Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
rather than for training, since it is only built with a single epoch sample mapping.
"""
args = get_args()
block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
kwargs = dict(
name='full',
block_dataset=block_dataset,
title_dataset=titles_dataset,
data_prefix=args.data_path,
num_epochs=1,
max_num_samples=None,
max_seq_length=args.seq_length,
seed=1,
query_in_block_prob=query_in_block_prob,
use_titles=use_titles,
use_one_sent_docs=args.use_one_sent_docs
)
dataset = ICTDataset(**kwargs)
return dataset
class ICTDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
self.query_in_block_prob = query_in_block_prob
self.block_dataset = block_dataset
self.title_dataset = title_dataset
self.rng = random.Random(self.seed)
self.use_titles = use_titles
self.use_one_sent_docs = use_one_sent_docs
self.samples_mapping = get_block_samples_mapping(
block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
self.tokenizer = get_tokenizer()
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
self.vocab_id_to_token_list = self.tokenizer.inv_vocab
self.cls_id = self.tokenizer.cls
self.sep_id = self.tokenizer.sep
self.mask_id = self.tokenizer.mask
self.pad_id = self.tokenizer.pad
def __len__(self):
return len(self.samples_mapping)
def __getitem__(self, idx):
"""Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
sample_data = self.samples_mapping[idx]
start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()
if self.use_titles:
title = self.title_dataset[int(doc_idx)]
title_pad_offset = 3 + len(title)
else:
title = None
title_pad_offset = 2
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1
# randint() is inclusive for Python rng
rand_sent_idx = self.rng.randint(0, len(block) - 1)
# keep the query in the context query_in_block_prob fraction of the time.
if self.rng.random() < self.query_in_block_prob:
query = block[rand_sent_idx].copy()
else:
query = block.pop(rand_sent_idx)
# still need to truncate because blocks are concluded when
# the sentence lengths have exceeded max_seq_length.
query = query[:self.max_seq_length - 2]
block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
query_mask = make_attention_mask(query_tokens, query_tokens)
context_mask = make_attention_mask(context_tokens, context_tokens)
block_data = sample_data.as_array()
sample = {
'query_tokens': query_tokens,
'query_mask': query_mask,
'query_pad_mask': query_pad_mask,
'context_tokens': context_tokens,
'context_mask': context_mask,
'context_pad_mask': context_pad_mask,
'block_data': block_data,
}
return sample
def get_block(self, start_idx, end_idx, doc_idx):
"""Get the IDs for an evidence block plus the title of the corresponding document"""
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
title = self.title_dataset[int(doc_idx)]
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return block_tokens, block_pad_mask
def get_null_block(self):
"""Get empty block and title - used in REALM pretraining"""
block, title = [], []
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return block_tokens, block_pad_mask
def concat_and_pad_tokens(self, tokens, title=None):
"""Concat with special tokens and pad sequence to self.max_seq_length"""
tokens = list(tokens)
if title is None:
tokens = [self.cls_id] + tokens + [self.sep_id]
else:
title = list(title)
tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
assert len(tokens) <= self.max_seq_length
num_pad = self.max_seq_length - len(tokens)
pad_mask = [1] * len(tokens) + [0] * num_pad
tokens += [self.pad_id] * num_pad
return np.array(tokens), np.array(pad_mask)
# This file isn't really a formal automated test, it's just a place to
# put some code used during development and manual testing of
# indexed_dataset.
from megatron.data import indexed_dataset
from megatron.tokenizer import build_tokenizer
import argparse
import os
import sys
import torch
script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "../../../"))
def test_indexed_dataset(args):
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
tokenizer = build_tokenizer(args)
print(len(ds.doc_idx))
print(len(ds))
print(ds.doc_idx[-1])
if ds.supports_prefetch:
# just prefetch the whole thing in test (so assume it is small)
ds.prefetch(range(len(ds)))
if args.count > len(ds.doc_idx) - 1:
args.count = len(ds.doc_idx) - 1
for i in range(args.count):
start = ds.doc_idx[i]
end = ds.doc_idx[i + 1]
ids = ds[start:end]
print(f"Document {i}:")
print("--------------")
for s in ids:
assert len(s) > 0
l = s.data.tolist()
text = tokenizer.detokenize(l)
print(text)
print("---")
def test_indexed_dataset_get(args):
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
tokenizer = build_tokenizer(args)
size = ds.sizes[0]
print(f"size: {size}")
full = ds.get(0)
print(full)
# print(tokenizer.detokenize(full.data.tolist()))
print("---")
end = ds.get(0, offset=size - 10)
print(end)
# print(tokenizer.detokenize(end.data.tolist()))
start = ds.get(0, length=10)
print(start)
# print(tokenizer.detokenize(start.data.tolist()))
part = ds.get(0, offset=2, length=8)
print(part)
# print(tokenizer.detokenize(part.data.tolist()))
# def test_albert_dataset(args):
# # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
# # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
# # ds = AlbertDataset(idataset, tokenizer)
# ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
# args.epochs, args.max_num_samples,
# args.masked_lm_prob, args.seq_length,
# args.short_seq_prob, args.seed)
# truncated = 0
# total = 0
# for i, s in enumerate(ds):
# ids = s['text']
# tokens = ds.tokenizer.convert_ids_to_tokens(ids)
# print(tokens)
# if i >= args.count-1:
# exit()
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='prefix to data files')
parser.add_argument('--dataset-impl', type=str, default='infer',
choices=['lazy', 'cached', 'mmap', 'infer'])
parser.add_argument('--count', type=int, default=10,
help='Number of samples/documents to print')
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase',
'GPT2BPETokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
parser.add_argument('--epochs', type=int, default=5,
help='Number of epochs to plan for')
parser.add_argument('--max-num-samples', type=int, default=None,
help='Maximum number of samples to plan for')
parser.add_argument('--masked-lm-prob', type=float, default=0.15,
help='probability of masking tokens')
parser.add_argument('--seq-length', type=int, default=512,
help='maximum sequence length')
parser.add_argument('--short-seq-prob', type=float, default=0.1,
help='probability of creating a short sequence')
parser.add_argument('--seed', type=int, default=1234,
help='random seed')
args = parser.parse_args()
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.tensor_model_parallel_size = 1
if args.dataset_impl == "infer":
args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)
# test_albert_dataset(args)
test_indexed_dataset_get(args)
if __name__ == "__main__":
main()
#!/bin/bash
IMPL=cached
python ../preprocess_data.py \
--input test_samples.json \
--vocab vocab.txt \
--dataset-impl ${IMPL} \
--output-prefix test_samples_${IMPL} \
--workers 1 \
--log-interval 2
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tokenizer import build_tokenizer
_TOKENIZER = None
_PADDED_VOCAB_SIZE = -1
def initialize_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0):
tokenizer, padded_vocab_size = build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids)
global _TOKENIZER, _PADDED_VOCAB_SIZE
_TOKENIZER = tokenizer
_PADDED_VOCAB_SIZE = padded_vocab_size
def get_tokenizer():
global _TOKENIZER
return _TOKENIZER
def get_padded_vocab_size():
global _PADDED_VOCAB_SIZE
return _PADDED_VOCAB_SIZE
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron tokenizers."""
from abc import ABC
from abc import abstractmethod
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from .bert_tokenization import FullTokenizer as FullBertTokenizer
def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0):
"""Initialize tokenizer."""
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
print('> building {} tokenizer ...'.format(tokenizer_type),
flush=True)
# Select and instantiate the tokenizer.
if tokenizer_type == 'BertWordPieceLowerCase':
tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file,
lower_case=True,
vocab_extra_ids=vocab_extra_ids)
elif tokenizer_type == 'BertWordPieceCase':
tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file,
lower_case=False,
vocab_extra_ids=vocab_extra_ids)
else:
raise NotImplementedError('{} tokenizer is not '
'implemented.'.format(tokenizer_type))
# Add vocab size.
padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size)
return tokenizer, padded_vocab_size
def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128):
"""Pad vocab size so it is divisible by model parallel size and
still having GPU friendly size."""
after = orig_vocab_size
if gpc.is_initialized(ParallelMode.TENSOR):
multiple = make_vocab_size_divisible_by * gpc.get_world_size(ParallelMode.TENSOR)
else:
multiple = make_vocab_size_divisible_by
while (after % multiple) != 0:
after += 1
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
print(' > padded vocab (size: {}) with {} dummy tokens '
'(new size: {})'.format(
orig_vocab_size, after - orig_vocab_size, after), flush=True)
return after
class AbstractTokenizer(ABC):
"""Abstract class for tokenizer."""
def __init__(self, name):
self.name = name
super().__init__()
@property
@abstractmethod
def vocab_size(self):
pass
@property
@abstractmethod
def vocab(self):
"""Dictionary from vocab text token to id token."""
pass
@property
@abstractmethod
def inv_vocab(self):
"""Dictionary from vocab id token to text token."""
pass
@abstractmethod
def tokenize(self, text):
pass
def detokenize(self, token_ids):
raise NotImplementedError('detokenizer is not implemented for {} '
'tokenizer'.format(self.name))
@property
def cls(self):
raise NotImplementedError('CLS is not provided for {} '
'tokenizer'.format(self.name))
@property
def sep(self):
raise NotImplementedError('SEP is not provided for {} '
'tokenizer'.format(self.name))
@property
def pad(self):
raise NotImplementedError('PAD is not provided for {} '
'tokenizer'.format(self.name))
@property
def eod(self):
raise NotImplementedError('EOD is not provided for {} '
'tokenizer'.format(self.name))
@property
def mask(self):
raise NotImplementedError('MASK is not provided for {} '
'tokenizer'.format(self.name))
class _BertWordPieceTokenizer(AbstractTokenizer):
"""Original BERT wordpiece tokenizer."""
def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0):
if lower_case:
name = 'BERT Lower Case'
else:
name = 'BERT Upper Case'
super().__init__(name)
self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
self.cls_id = self.tokenizer.vocab['[CLS]']
self.sep_id = self.tokenizer.vocab['[SEP]']
self.pad_id = self.tokenizer.vocab['[PAD]']
self.mask_id = self.tokenizer.vocab['[MASK]']
self._additional_special_tokens = []
# (dsachan) Add BOS and EOS tokens
SPECIAL_TOKENS = {'eos_token': '[EOS]',
'bos_token': '[BOS]'}
self._bos_token = '[BOS]'
self.add_token(self._bos_token)
self._bos_token_id = self.vocab.get(self._bos_token)
self._eos_token = '[EOS]'
self.add_token(self._eos_token)
self._eos_token_id = self.vocab.get(self._eos_token)
# (dsachan) Add additional special tokens
# These can be used as sentinel tokens in T5 model inputs
additional_special_tokens = []
additional_special_tokens.extend(
["<extra_id_{}>".format(i) for i in range(vocab_extra_ids)])
self.add_additional_special_tokens(additional_special_tokens)
def add_token(self, token):
if token not in self.vocab:
self.inv_vocab[self.vocab_size] = token
# self.vocab_size comes from len(vocab)
# and it will increase as we add elements
self.vocab[token] = self.vocab_size
def add_additional_special_tokens(self, tokens_list):
setattr(self, "additional_special_tokens", tokens_list)
for value in tokens_list:
self.add_token(value)
@property
def vocab_size(self):
return self.tokenizer.vocab_size()
@property
def vocab(self):
return self.tokenizer.vocab
@property
def inv_vocab(self):
return self.tokenizer.inv_vocab
def tokenize(self, text):
text_tokens = self.tokenizer.tokenize(text)
return self.tokenizer.convert_tokens_to_ids(text_tokens)
def decode(self, ids):
tokens = self.tokenizer.convert_ids_to_tokens(ids)
return self.tokenizer.convert_tokens_to_string(tokens)
def decode_token_ids(self, token_ids):
tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
exclude_list = ['[PAD]', '[CLS]']
non_pads = [t for t in tokens if t not in exclude_list]
result = ""
for s in non_pads:
if s.startswith("##"):
result += s[2:]
else:
result += " " + s
return result
@property
def cls(self):
return self.cls_id
@property
def sep(self):
return self.sep_id
@property
def pad(self):
return self.pad_id
@property
def mask(self):
return self.mask_id
@property
def bos_token(self):
""" Beginning of sentence token id """
return self._bos_token
@property
def eos_token(self):
""" End of sentence token id """
return self._eos_token
@property
def additional_special_tokens(self):
""" All the additional special tokens you may want to use (list of strings)."""
return self._additional_special_tokens
@property
def bos_token_id(self):
""" Id of the beginning of sentence token in the vocabulary."""
return self._bos_token_id
@property
def eos_token_id(self):
""" Id of the end of sentence token in the vocabulary."""
return self._eos_token_id
@property
def additional_special_tokens_ids(self):
""" Ids of all the additional special tokens in the vocabulary (list of integers)."""
return [self.vocab.get(token) for token in self._additional_special_tokens]
@additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value
import torch
import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.logging import get_dist_logger
import torch.nn.functional as F
import torch.distributed as dist
from .cross_entropy import vocab_cross_entropy
class BertLoss(nn.Module):
def forward(self,
lm_loss,
sop_logits,
loss_mask,
sentence_order):
lm_loss_ = lm_loss.float()
loss_mask = loss_mask.float()
loss_mask_sum = loss_mask.sum()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1))
lm_loss /= loss_mask_sum
torch.distributed.all_reduce(
lm_loss,
group=gpc.get_group(ParallelMode.SEQUENCE)
)
if sop_logits is not None:
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = sop_loss.float()
loss = lm_loss + sop_loss * gpc.get_world_size(ParallelMode.SEQUENCE)
else:
sop_loss = None
loss = lm_loss
return loss
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from .embedding import VocabEmbedding, Embedding
from .bert_layer import BertLayer
from .head import BertDualHead
from .preprocess import PreProcessor
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment