Commit b7f1b050 authored by Neel Kant

Lint whole repo

parent c99fa80c
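Neither the commit message nor the diff records which formatter was used; the changes (operator spacing, long-line wrapping, added blank lines between top-level definitions, bare `except:` rewritten to `except BaseException:`) look like the output of an autopep8-style pass. A hedged sketch of applying such a pass programmatically, assuming autopep8 is installed; the tool choice and options are illustrative, not taken from the commit:

# Hypothetical reproduction of a repo-wide formatting pass; the linter used
# for this commit is an assumption (autopep8), options are illustrative.
import autopep8

messy = "d = {v:k for k,v in mapping.items()}\n"
print(autopep8.fix_code(messy, options={"aggressive": 1}), end="")
# -> d = {v: k for k, v in mapping.items()}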
...@@ -26,6 +26,7 @@ class IdentityLayer(torch.nn.Module): ...@@ -26,6 +26,7 @@ class IdentityLayer(torch.nn.Module):
def __init__(self, size, scale=1.0): def __init__(self, size, scale=1.0):
super(IdentityLayer, self).__init__() super(IdentityLayer, self).__init__()
self.weight = torch.nn.Parameter(scale * torch.randn(size)) self.weight = torch.nn.Parameter(scale * torch.randn(size))
def forward(self): def forward(self):
return self.weight return self.weight
......
...@@ -13,20 +13,18 @@ ...@@ -13,20 +13,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from commons import set_random_seed
from commons import IdentityLayer
from commons import print_separator
from commons import initialize_distributed
from mpu.cross_entropy import vocab_parallel_cross_entropy
import mpu
import torch.nn.functional as F
import torch
import random import random
import sys import sys
sys.path.append("../..") sys.path.append("../..")
import torch
import torch.nn.functional as F
import mpu
from mpu.cross_entropy import vocab_parallel_cross_entropy
from commons import initialize_distributed
from commons import print_separator
from commons import IdentityLayer
from commons import set_random_seed
def torch_cross_entropy(batch_size, seq_length, vocab_size, def torch_cross_entropy(batch_size, seq_length, vocab_size,
logits_scale, seed): logits_scale, seed):
......
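A side note on the import reordering in this and the other test files below: the project-level imports such as `import mpu` only resolve after `sys.path.append("../..")` has run, so the linted files keep the path setup first and the project imports after it. A tiny self-contained illustration of that ordering rule (the temporary module name and directory are made up):

# An import only needs its package on sys.path at the moment the import
# statement executes, which is why path setup stays above the imports.
import os
import sys
import tempfile

pkg_root = tempfile.mkdtemp()
with open(os.path.join(pkg_root, "fake_mpu.py"), "w") as f:
    f.write("VALUE = 42\n")

sys.path.append(pkg_root)   # path setup first, as in the linted tests
import fake_mpu             # now resolvable, like `import mpu` above

print(fake_mpu.VALUE)       # 42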
...@@ -13,18 +13,16 @@ ...@@ -13,18 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from commons import print_separator
from commons import initialize_distributed
from mpu import data as data_utils
import mpu
import torch
import functools import functools
import operator import operator
import sys import sys
sys.path.append("../..") sys.path.append("../..")
import torch
import mpu
from mpu import data as data_utils
from commons import initialize_distributed
from commons import print_separator
def test_boradcast_data(model_parallel_size): def test_boradcast_data(model_parallel_size):
...@@ -88,5 +86,3 @@ if __name__ == '__main__': ...@@ -88,5 +86,3 @@ if __name__ == '__main__':
print_separator('test test boradcast data') print_separator('test test boradcast data')
test_boradcast_data(model_parallel_size) test_boradcast_data(model_parallel_size)
model_parallel_size *= 2 model_parallel_size *= 2
...@@ -13,15 +13,13 @@ ...@@ -13,15 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from commons import print_separator
from commons import initialize_distributed
import mpu
import torch
import sys import sys
sys.path.append("../..") sys.path.append("../..")
import torch
import mpu
from commons import initialize_distributed
from commons import print_separator
def test_initialize_model_parallel(model_parallel_size): def test_initialize_model_parallel(model_parallel_size):
...@@ -46,7 +44,6 @@ def test_initialize_model_parallel(model_parallel_size): ...@@ -46,7 +44,6 @@ def test_initialize_model_parallel(model_parallel_size):
assert rank == mpu.get_model_parallel_rank() assert rank == mpu.get_model_parallel_rank()
check(mpu.get_model_parallel_group(), world_size, rank) check(mpu.get_model_parallel_group(), world_size, rank)
# Data parallel. # Data parallel.
world_size = torch.distributed.get_world_size() // model_parallel_size_ world_size = torch.distributed.get_world_size() // model_parallel_size_
rank = torch.distributed.get_rank() // model_parallel_size rank = torch.distributed.get_rank() // model_parallel_size
......
...@@ -13,20 +13,18 @@ ...@@ -13,20 +13,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from mpu import layers
from commons import set_random_seed
from commons import print_separator
from commons import initialize_distributed
import mpu
from torch.nn.parameter import Parameter
import torch.nn.init as init
import torch
import random import random
import sys import sys
sys.path.append("../..") sys.path.append("../..")
import torch
import torch.nn.init as init
from torch.nn.parameter import Parameter
import mpu
from commons import initialize_distributed
from commons import print_separator
from commons import set_random_seed
from mpu import layers
def test_parallel_embedding(model_parallel_size): def test_parallel_embedding(model_parallel_size):
...@@ -45,7 +43,7 @@ def test_parallel_embedding(model_parallel_size): ...@@ -45,7 +43,7 @@ def test_parallel_embedding(model_parallel_size):
set_random_seed(123) set_random_seed(123)
input_data = torch.LongTensor( input_data = torch.LongTensor(
size=(batch_size,seq_length)).random_(0, vocab_size).cuda() size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda() loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()
set_random_seed(seed) set_random_seed(seed)
...@@ -57,7 +55,7 @@ def test_parallel_embedding(model_parallel_size): ...@@ -57,7 +55,7 @@ def test_parallel_embedding(model_parallel_size):
set_random_seed(seed) set_random_seed(seed)
embedding_parallel = layers.ParallelEmbedding( embedding_parallel = layers.ParallelEmbedding(
vocab_size, hidden_size, init_method=init.normal_).cuda() vocab_size, hidden_size, init_method=init.normal_).cuda()
output = embedding_parallel(input_data) output = embedding_parallel(input_data)
loss_parallel = torch.mul(output, loss_weight).sum() loss_parallel = torch.mul(output, loss_weight).sum()
loss_parallel.backward() loss_parallel.backward()
...@@ -176,10 +174,11 @@ def test_initialize_affine_weight(model_parallel_size): ...@@ -176,10 +174,11 @@ def test_initialize_affine_weight(model_parallel_size):
class IdentityLayer2D(torch.nn.Module): class IdentityLayer2D(torch.nn.Module):
def __init__(self, m , n): def __init__(self, m, n):
super(IdentityLayer2D, self).__init__() super(IdentityLayer2D, self).__init__()
self.weight = Parameter(torch.Tensor(m, n)) self.weight = Parameter(torch.Tensor(m, n))
torch.nn.init.xavier_normal_(self.weight) torch.nn.init.xavier_normal_(self.weight)
def forward(self): def forward(self):
return self.weight return self.weight
...@@ -317,10 +316,11 @@ def test_row_parallel_linear(model_parallel_size): ...@@ -317,10 +316,11 @@ def test_row_parallel_linear(model_parallel_size):
class IdentityLayer3D(torch.nn.Module): class IdentityLayer3D(torch.nn.Module):
def __init__(self, m , n, k): def __init__(self, m, n, k):
super(IdentityLayer3D, self).__init__() super(IdentityLayer3D, self).__init__()
self.weight = Parameter(torch.Tensor(m, n, k)) self.weight = Parameter(torch.Tensor(m, n, k))
torch.nn.init.xavier_normal_(self.weight) torch.nn.init.xavier_normal_(self.weight)
def forward(self): def forward(self):
return self.weight return self.weight
...@@ -335,14 +335,14 @@ def parallel_self_attention(model_parallel_size, num_att_heads_per_partition, ...@@ -335,14 +335,14 @@ def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
set_random_seed(seed) set_random_seed(seed)
num_att_heads = num_att_heads_per_partition * \ num_att_heads = num_att_heads_per_partition * \
torch.distributed.get_world_size() torch.distributed.get_world_size()
hidden_size = hidden_size_per_att_head * num_att_heads hidden_size = hidden_size_per_att_head * num_att_heads
# Network # Network
identity_layer = IdentityLayer3D(batch_size, sequence_length, identity_layer = IdentityLayer3D(batch_size, sequence_length,
hidden_size).cuda() hidden_size).cuda()
attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads, attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads,
dropout_prob).cuda() dropout_prob).cuda()
loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda() loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
# Forward # Forward
...@@ -366,17 +366,17 @@ def test_parallel_self_attention(model_parallel_size): ...@@ -366,17 +366,17 @@ def test_parallel_self_attention(model_parallel_size):
num_att_heads_per_partition = 3 num_att_heads_per_partition = 3
hidden_size_per_att_head = 7 hidden_size_per_att_head = 7
dropout_prob = 0.0 # has to be zero dropout_prob = 0.0 # has to be zero
batch_size = 5 batch_size = 5
sequence_length = 13 sequence_length = 13
rank_1, hideen_size_1, model_parallel_size_1, loss_1, \ rank_1, hideen_size_1, model_parallel_size_1, loss_1, \
attention_layer_1, identity_layer_1 =parallel_self_attention( attention_layer_1, identity_layer_1 = parallel_self_attention(
1, num_att_heads_per_partition, 1, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
rank, hidden_size, model_parallel_size, loss, \ rank, hidden_size, model_parallel_size, loss, \
attention_layer, identity_layer =parallel_self_attention( attention_layer, identity_layer = parallel_self_attention(
model_parallel_size, num_att_heads_per_partition, model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
assert hideen_size_1 == hidden_size assert hideen_size_1 == hidden_size
...@@ -409,6 +409,7 @@ def test_parallel_self_attention(model_parallel_size): ...@@ -409,6 +409,7 @@ def test_parallel_self_attention(model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)') print(' >> passed the test :-)')
def parallel_transformer(model_parallel_size, num_att_heads_per_partition, def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length): hidden_size_per_att_head, batch_size, sequence_length):
...@@ -419,7 +420,7 @@ def parallel_transformer(model_parallel_size, num_att_heads_per_partition, ...@@ -419,7 +420,7 @@ def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
set_random_seed(seed) set_random_seed(seed)
num_att_heads = num_att_heads_per_partition * \ num_att_heads = num_att_heads_per_partition * \
torch.distributed.get_world_size() torch.distributed.get_world_size()
hidden_size = hidden_size_per_att_head * num_att_heads hidden_size = hidden_size_per_att_head * num_att_heads
intermediate_size = 4 * hidden_size intermediate_size = 4 * hidden_size
......
...@@ -13,15 +13,13 @@ ...@@ -13,15 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from commons import print_separator
from commons import initialize_distributed
import mpu
import torch
import sys import sys
sys.path.append("../..") sys.path.append("../..")
import torch
import mpu
from commons import initialize_distributed
from commons import print_separator
def test_set_cuda_rng_state(model_parallel_size): def test_set_cuda_rng_state(model_parallel_size):
...@@ -204,4 +202,3 @@ if __name__ == '__main__': ...@@ -204,4 +202,3 @@ if __name__ == '__main__':
print_separator('test model parallel cuda manual seed') print_separator('test model parallel cuda manual seed')
test_model_parallel_cuda_manual_seed(model_parallel_size) test_model_parallel_cuda_manual_seed(model_parallel_size)
model_parallel_size *= 2 model_parallel_size *= 2
...@@ -120,8 +120,8 @@ def generate_samples_input_from_file(model): ...@@ -120,8 +120,8 @@ def generate_samples_input_from_file(model):
context_length = len(context_tokens) context_length = len(context_tokens)
if context_length >= (args.seq_length // 2): if context_length >= (args.seq_length // 2):
print("\nContext length", context_length, \ print("\nContext length", context_length,
"\nPlease give smaller context (half of the " "\nPlease give smaller context (half of the "
"sequence length)!", flush=True) "sequence length)!", flush=True)
continue continue
else: else:
...@@ -187,8 +187,8 @@ def generate_samples_interactive(model, print_frequency=24): ...@@ -187,8 +187,8 @@ def generate_samples_interactive(model, print_frequency=24):
context_length = len(context_tokens) context_length = len(context_tokens)
if context_length >= (args.seq_length // 2): if context_length >= (args.seq_length // 2):
print("\nContext length", context_length, \ print("\nContext length", context_length,
"\nPlease give smaller context (half of the " "\nPlease give smaller context (half of the "
"sequence length)!", flush=True) "sequence length)!", flush=True)
continue continue
else: else:
...@@ -246,7 +246,7 @@ def generate_samples_unconditional(model): ...@@ -246,7 +246,7 @@ def generate_samples_unconditional(model):
for token_stream in get_token_stream(model, for token_stream in get_token_stream(model,
copy.deepcopy(context_tokens)): copy.deepcopy(context_tokens)):
pass pass
if ctr%args.log_interval == 0: if ctr % args.log_interval == 0:
print('Avg s/batch:', print('Avg s/batch:',
(time.time() - start_time) / min(args.log_interval, ctr + 1)) (time.time() - start_time) / min(args.log_interval, ctr + 1))
start_time = time.time() start_time = time.time()
...@@ -254,10 +254,10 @@ def generate_samples_unconditional(model): ...@@ -254,10 +254,10 @@ def generate_samples_unconditional(model):
token_batch = token_stream[0].cpu().numpy().tolist() token_batch = token_stream[0].cpu().numpy().tolist()
length_batch = token_stream[1].cpu().numpy().tolist() length_batch = token_stream[1].cpu().numpy().tolist()
for tokens, length in zip(token_batch, length_batch): for tokens, length in zip(token_batch, length_batch):
tokens = tokens[1:length-1] tokens = tokens[1:length - 1]
text = tokenizer.detokenize(tokens) text = tokenizer.detokenize(tokens)
is_finished = length < args.seq_length - 1 is_finished = length < args.seq_length - 1
datum = {'text': text, 'length': length-1, 'finished': is_finished} datum = {'text': text, 'length': length - 1, 'finished': is_finished}
yield datum yield datum
ctr += 1 ctr += 1
if ctr >= num_samples: if ctr >= num_samples:
...@@ -272,7 +272,7 @@ def generate_and_write_samples_unconditional(model): ...@@ -272,7 +272,7 @@ def generate_and_write_samples_unconditional(model):
assert args.genfile is not None assert args.genfile is not None
with open(args.genfile, 'w') as f: with open(args.genfile, 'w') as f:
for datum in generate_samples_unconditional(model): for datum in generate_samples_unconditional(model):
f.write(json.dumps(datum)+'\n') f.write(json.dumps(datum) + '\n')
def pad_batch(batch, pad_id, args): def pad_batch(batch, pad_id, args):
...@@ -281,7 +281,7 @@ def pad_batch(batch, pad_id, args): ...@@ -281,7 +281,7 @@ def pad_batch(batch, pad_id, args):
for tokens in batch: for tokens in batch:
context_length = len(tokens) context_length = len(tokens)
if context_length < args.seq_length: if context_length < args.seq_length:
tokens.extend([pad_id]*(args.seq_length - context_length)) tokens.extend([pad_id] * (args.seq_length - context_length))
context_lengths.append(context_length) context_lengths.append(context_length)
return batch, context_lengths return batch, context_lengths
...@@ -345,7 +345,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, ...@@ -345,7 +345,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
if maxlen > (org_context_length + args.out_seq_length): if maxlen > (org_context_length + args.out_seq_length):
maxlen = org_context_length + args.out_seq_length maxlen = org_context_length + args.out_seq_length
lengths = torch.ones([batch_size]).long().cuda()*maxlen lengths = torch.ones([batch_size]).long().cuda() * maxlen
while context_length <= (maxlen): while context_length <= (maxlen):
......
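For reference, a minimal standalone version of the `pad_batch` helper touched above, with a toy batch showing the padded output (token values and `pad_id` are arbitrary):

def pad_batch_sketch(batch, pad_id, seq_length):
    # Pad every context up to seq_length and remember the original lengths.
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < seq_length:
            tokens.extend([pad_id] * (seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths

padded, lengths = pad_batch_sketch([[5, 6, 7], [8, 9]], pad_id=0, seq_length=4)
# padded  -> [[5, 6, 7, 0], [8, 9, 0, 0]]
# lengths -> [3, 2]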
...@@ -25,377 +25,377 @@ import six ...@@ -25,377 +25,377 @@ import six
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name.""" """Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check # The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably # as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so # should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate. # we have to heuristically detect it to validate.
if not init_checkpoint: if not init_checkpoint:
return return
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None: if m is None:
return return
model_name = m.group(1) model_name = m.group(1)
lower_models = [ lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
] ]
cased_models = [ cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12" "multi_cased_L-12_H-768_A-12"
] ]
is_bad_config = False is_bad_config = False
if model_name in lower_models and not do_lower_case: if model_name in lower_models and not do_lower_case:
is_bad_config = True is_bad_config = True
actual_flag = "False" actual_flag = "False"
case_name = "lowercased" case_name = "lowercased"
opposite_flag = "True" opposite_flag = "True"
if model_name in cased_models and do_lower_case: if model_name in cased_models and do_lower_case:
is_bad_config = True is_bad_config = True
actual_flag = "True" actual_flag = "True"
case_name = "cased" case_name = "cased"
opposite_flag = "False" opposite_flag = "False"
if is_bad_config: if is_bad_config:
raise ValueError( raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you " "However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches " "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please " "how the model was pre-training. If this error is wrong, please "
"just comment out this check." % (actual_flag, init_checkpoint, "just comment out this check." % (actual_flag, init_checkpoint,
model_name, case_name, opposite_flag)) model_name, case_name, opposite_flag))
def convert_to_unicode(text): def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3: if six.PY3:
if isinstance(text, str): if isinstance(text, str):
return text return text
elif isinstance(text, bytes): elif isinstance(text, bytes):
return text.decode("utf-8", "ignore") return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        elif isinstance(text, unicode):
            return text
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text): def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`.""" """Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case # These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string. # it's a Unicode string and in the other it's a byte string.
if six.PY3: if six.PY3:
if isinstance(text, str): if isinstance(text, str):
return text return text
elif isinstance(text, bytes): elif isinstance(text, bytes):
return text.decode("utf-8", "ignore") return text.decode("utf-8", "ignore")
else: else:
raise ValueError("Unsupported string type: %s" % (type(text))) raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2: elif six.PY2:
if isinstance(text, str): if isinstance(text, str):
return text return text
elif isinstance(text, unicode): elif isinstance(text, unicode):
return text.encode("utf-8") return text.encode("utf-8")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file): def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary.""" """Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict() vocab = collections.OrderedDict()
index = 0 index = 0
with open(vocab_file, "r") as reader: with open(vocab_file, "r") as reader:
while True: while True:
token = convert_to_unicode(reader.readline()) token = convert_to_unicode(reader.readline())
if not token: if not token:
break break
token = token.strip() token = token.strip()
vocab[token] = index vocab[token] = index
index += 1 index += 1
return vocab return vocab
def convert_by_vocab(vocab, items): def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab.""" """Converts a sequence of [tokens|ids] using the vocab."""
output = [] output = []
for item in items: for item in items:
output.append(vocab[item]) output.append(vocab[item])
return output return output
def convert_tokens_to_ids(vocab, tokens): def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens) return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids): def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids) return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text): def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text.""" """Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip() text = text.strip()
if not text: if not text:
return [] return []
tokens = text.split() tokens = text.split()
return tokens return tokens
class FullTokenizer(object): class FullTokenizer(object):
"""Runs end-to-end tokenziation.""" """Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True): def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file) self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()} self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text): def tokenize(self, text):
split_tokens = [] split_tokens = []
for token in self.basic_tokenizer.tokenize(text): for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token): for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token) split_tokens.append(sub_token)
return split_tokens return split_tokens
def convert_tokens_to_ids(self, tokens): def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens) return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids): def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids) return convert_by_vocab(self.inv_vocab, ids)
def vocab_size(self): def vocab_size(self):
return len(self.vocab) return len(self.vocab)
class BasicTokenizer(object): class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.).""" """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True): def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer. """Constructs a BasicTokenizer.
Args: Args:
do_lower_case: Whether to lower case the input. do_lower_case: Whether to lower case the input.
""" """
self.do_lower_case = do_lower_case self.do_lower_case = do_lower_case
def tokenize(self, text): def tokenize(self, text):
"""Tokenizes a piece of text.""" """Tokenizes a piece of text."""
text = convert_to_unicode(text) text = convert_to_unicode(text)
text = self._clean_text(text) text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese # This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't # models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data # matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese # and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese # characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.). # words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text) text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text) orig_tokens = whitespace_tokenize(text)
split_tokens = [] split_tokens = []
for token in orig_tokens: for token in orig_tokens:
if self.do_lower_case: if self.do_lower_case:
token = token.lower() token = token.lower()
token = self._run_strip_accents(token) token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token)) split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens)) output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens return output_tokens
def _run_strip_accents(self, text): def _run_strip_accents(self, text):
"""Strips accents from a piece of text.""" """Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text) text = unicodedata.normalize("NFD", text)
output = [] output = []
for char in text: for char in text:
cat = unicodedata.category(char) cat = unicodedata.category(char)
if cat == "Mn": if cat == "Mn":
continue continue
output.append(char) output.append(char)
return "".join(output) return "".join(output)
    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
                (cp >= 0x3400 and cp <= 0x4DBF) or  #
                (cp >= 0x20000 and cp <= 0x2A6DF) or  #
                (cp >= 0x2A700 and cp <= 0x2B73F) or  #
                (cp >= 0x2B740 and cp <= 0x2B81F) or  #
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or  #
                (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
class WordpieceTokenizer(object): class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation.""" """Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab self.vocab = vocab
self.unk_token = unk_token self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text): def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces. """Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary. using the given vocabulary.
For example: For example:
input = "unaffable" input = "unaffable"
output = ["un", "##aff", "##able"] output = ["un", "##aff", "##able"]
Args: Args:
text: A single token or whitespace separated tokens. This should have text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer. already been passed through `BasicTokenizer.
Returns: Returns:
A list of wordpiece tokens. A list of wordpiece tokens.
""" """
text = convert_to_unicode(text) text = convert_to_unicode(text)
output_tokens = [] output_tokens = []
for token in whitespace_tokenize(text): for token in whitespace_tokenize(text):
chars = list(token) chars = list(token)
if len(chars) > self.max_input_chars_per_word: if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token) output_tokens.append(self.unk_token)
continue continue
is_bad = False is_bad = False
start = 0 start = 0
sub_tokens = [] sub_tokens = []
while start < len(chars): while start < len(chars):
end = len(chars) end = len(chars)
cur_substr = None cur_substr = None
while start < end: while start < end:
substr = "".join(chars[start:end]) substr = "".join(chars[start:end])
if start > 0: if start > 0:
substr = "##" + substr substr = "##" + substr
if substr in self.vocab: if substr in self.vocab:
cur_substr = substr cur_substr = substr
break break
end -= 1 end -= 1
if cur_substr is None: if cur_substr is None:
is_bad = True is_bad = True
break break
sub_tokens.append(cur_substr) sub_tokens.append(cur_substr)
start = end start = end
if is_bad: if is_bad:
output_tokens.append(self.unk_token) output_tokens.append(self.unk_token)
else: else:
output_tokens.extend(sub_tokens) output_tokens.extend(sub_tokens)
return output_tokens return output_tokens
def _is_whitespace(char): def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character.""" """Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them # \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such. # as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r": if char == " " or char == "\t" or char == "\n" or char == "\r":
return True return True
cat = unicodedata.category(char) cat = unicodedata.category(char)
if cat == "Zs": if cat == "Zs":
return True return True
return False return False
def _is_control(char): def _is_control(char):
"""Checks whether `chars` is a control character.""" """Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace # These are technically control characters but we count them as whitespace
# characters. # characters.
if char == "\t" or char == "\n" or char == "\r": if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat in ("Cc", "Cf"):
        return True
    return False
def _is_punctuation(char): def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character.""" """Checks whether `chars` is a punctuation character."""
cp = ord(char) cp = ord(char)
# We treat all non-letter/number ASCII as punctuation. # We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode # Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for # Punctuation class but we treat them as punctuation anyways, for
# consistency. # consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True return True
cat = unicodedata.category(char) cat = unicodedata.category(char)
if cat.startswith("P"): if cat.startswith("P"):
return True return True
return False return False
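The `WordpieceTokenizer.tokenize` docstring above describes a greedy longest-match-first pass ("unaffable" -> ["un", "##aff", "##able"]). A self-contained sketch of that loop with a made-up four-entry vocabulary, for readers looking at this diff without the rest of the file:

def wordpiece_sketch(token, vocab, unk="[UNK]"):
    # Greedy longest-match-first: take the longest vocab entry starting at
    # `start`, prefixing non-initial pieces with "##".
    chars = list(token)
    sub_tokens, start = [], 0
    while start < len(chars):
        end = len(chars)
        cur = None
        while start < end:
            substr = "".join(chars[start:end])
            if start > 0:
                substr = "##" + substr
            if substr in vocab:
                cur = substr
                break
            end -= 1
        if cur is None:
            return [unk]
        sub_tokens.append(cur)
        start = end
    return sub_tokens

print(wordpiece_sketch("unaffable", {"un", "##aff", "##able", "a"}))
# ['un', '##aff', '##able']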
...@@ -29,7 +29,8 @@ try: ...@@ -29,7 +29,8 @@ try:
from functools import lru_cache from functools import lru_cache
except ImportError: except ImportError:
# Just a dummy decorator to get the checks to run on python2 # Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. # because honestly I don't want to support a byte-level unicode BPE
# tokenizer on python 2 right now.
def lru_cache(): def lru_cache():
return lambda func: func return lambda func: func
...@@ -49,6 +50,7 @@ VOCAB_NAME = 'vocab.json' ...@@ -49,6 +50,7 @@ VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt' MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt' SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache() @lru_cache()
def bytes_to_unicode(): def bytes_to_unicode():
""" """
...@@ -61,17 +63,19 @@ def bytes_to_unicode(): ...@@ -61,17 +63,19 @@ def bytes_to_unicode():
And avoids mapping to whitespace/control characters the bpe code barfs on. And avoids mapping to whitespace/control characters the bpe code barfs on.
""" """
_chr = unichr if sys.version_info[0] == 2 else chr _chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:] cs = bs[:]
n = 0 n = 0
for b in range(2**8): for b in range(2**8):
if b not in bs: if b not in bs:
bs.append(b) bs.append(b)
cs.append(2**8+n) cs.append(2**8 + n)
n += 1 n += 1
cs = [_chr(n) for n in cs] cs = [_chr(n) for n in cs]
return dict(zip(bs, cs)) return dict(zip(bs, cs))
def get_pairs(word): def get_pairs(word):
"""Return set of symbol pairs in a word. """Return set of symbol pairs in a word.
...@@ -84,6 +88,7 @@ def get_pairs(word): ...@@ -84,6 +88,7 @@ def get_pairs(word):
prev_char = char prev_char = char
return pairs return pairs
class GPT2Tokenizer(object): class GPT2Tokenizer(object):
""" """
GPT-2 BPE tokenizer. Peculiarities: GPT-2 BPE tokenizer. Peculiarities:
...@@ -140,23 +145,31 @@ class GPT2Tokenizer(object): ...@@ -140,23 +145,31 @@ class GPT2Tokenizer(object):
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else: else:
special_tokens = kwargs.pop('special_tokens', []) special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) tokenizer = cls(
resolved_vocab_file,
resolved_merges_file,
special_tokens=special_tokens,
*inputs,
**kwargs)
return tokenizer return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None): def __init__(self, vocab_file, merges_file, errors='replace',
special_tokens=None, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12) self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file)) self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()} self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode() self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data] bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {} self.cache = {}
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions # Should haved added re.IGNORECASE so BPE merges can happen for
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") # capitalized versions of contractions
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {} self.special_tokens = {}
self.special_tokens_decoder = {} self.special_tokens_decoder = {}
...@@ -174,8 +187,9 @@ class GPT2Tokenizer(object): ...@@ -174,8 +187,9 @@ class GPT2Tokenizer(object):
self.special_tokens = {} self.special_tokens = {}
self.special_tokens_decoder = {} self.special_tokens_decoder = {}
return return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) self.special_tokens = dict((tok, len(self.encoder) + i)
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens)) logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token): def bpe(self, token):
...@@ -188,7 +202,7 @@ class GPT2Tokenizer(object): ...@@ -188,7 +202,7 @@ class GPT2Tokenizer(object):
return token return token
while True: while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks: if bigram not in self.bpe_ranks:
break break
first, second = bigram first, second = bigram
...@@ -199,12 +213,12 @@ class GPT2Tokenizer(object): ...@@ -199,12 +213,12 @@ class GPT2Tokenizer(object):
j = word.index(first, i) j = word.index(first, i)
new_word.extend(word[i:j]) new_word.extend(word[i:j])
i = j i = j
except: except BaseException:
new_word.extend(word[i:]) new_word.extend(word[i:])
break break
if word[i] == first and i < len(word)-1 and word[i+1] == second: if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first+second) new_word.append(first + second)
i += 2 i += 2
else: else:
new_word.append(word[i]) new_word.append(word[i])
...@@ -247,7 +261,8 @@ class GPT2Tokenizer(object): ...@@ -247,7 +261,8 @@ class GPT2Tokenizer(object):
logger.warning( logger.warning(
"Token indices sequence length is longer than the specified maximum " "Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this" " sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len) " sequence through the model will result in indexing errors".format(
len(ids), self.max_len)
) )
return ids return ids
......
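For context on the `bytes_to_unicode` hunk above: the function builds a reversible map from every byte value to a printable unicode character, so the BPE step never operates on raw whitespace or control bytes. A Python 3-only standalone copy with a small probe (the 'Ġ' stand-in for the space byte matches GPT-2's published vocabulary):

def bytes_to_unicode_sketch():
    # Printable latin-1 bytes keep their own character; everything else is
    # remapped to code points starting at 256.
    bs = list(range(ord("!"), ord("~") + 1)) + \
        list(range(ord("¡"), ord("¬") + 1)) + \
        list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))

mapping = bytes_to_unicode_sketch()
print(mapping[ord("A")])   # 'A'  (printable bytes map to themselves)
print(mapping[ord(" ")])   # 'Ġ'  (space gets a printable stand-in)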
...@@ -32,7 +32,7 @@ def build_tokenizer(args): ...@@ -32,7 +32,7 @@ def build_tokenizer(args):
assert args.vocab_file is not None assert args.vocab_file is not None
if args.tokenizer_type == 'BertWordPieceLowerCase': if args.tokenizer_type == 'BertWordPieceLowerCase':
tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file, tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
lower_case=True) lower_case=True)
elif args.tokenizer_type == 'GPT2BPETokenizer': elif args.tokenizer_type == 'GPT2BPETokenizer':
assert args.merge_file is not None assert args.merge_file is not None
tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file) tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
...@@ -53,7 +53,7 @@ def _vocab_size_with_padding(orig_vocab_size, args): ...@@ -53,7 +53,7 @@ def _vocab_size_with_padding(orig_vocab_size, args):
after = orig_vocab_size after = orig_vocab_size
multiple = args.make_vocab_size_divisible_by * \ multiple = args.make_vocab_size_divisible_by * \
args.model_parallel_size args.model_parallel_size
while (after % multiple) != 0: while (after % multiple) != 0:
after += 1 after += 1
if args.rank == 0: if args.rank == 0:
...@@ -134,7 +134,7 @@ class _BertWordPieceTokenizer(AbstractTokenizer): ...@@ -134,7 +134,7 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
self.cls_id = self.tokenizer.vocab['[CLS]'] self.cls_id = self.tokenizer.vocab['[CLS]']
self.sep_id = self.tokenizer.vocab['[SEP]'] self.sep_id = self.tokenizer.vocab['[SEP]']
self.pad_id = self.tokenizer.vocab['[PAD]'] self.pad_id = self.tokenizer.vocab['[PAD]']
self.mask_id = self.tokenizer.vocab['[MASK]'] self.mask_id = self.tokenizer.vocab['[MASK]']
@property @property
def vocab_size(self): def vocab_size(self):
...@@ -168,6 +168,7 @@ class _BertWordPieceTokenizer(AbstractTokenizer): ...@@ -168,6 +168,7 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
def mask(self): def mask(self):
return self.mask_id return self.mask_id
class _GPT2BPETokenizer(AbstractTokenizer): class _GPT2BPETokenizer(AbstractTokenizer):
"""Original GPT2 BPE tokenizer.""" """Original GPT2 BPE tokenizer."""
......
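The `_vocab_size_with_padding` hunk above only changes line wrapping; for clarity, the same rounding logic as a standalone function with a worked example (the 30522/128/2 numbers are illustrative, not from any particular config):

def vocab_size_with_padding_sketch(orig_vocab_size, make_vocab_size_divisible_by,
                                   model_parallel_size):
    # Grow the vocab until it divides evenly across model-parallel partitions.
    multiple = make_vocab_size_divisible_by * model_parallel_size
    after = orig_vocab_size
    while (after % multiple) != 0:
        after += 1
    return after

print(vocab_size_with_padding_sketch(30522, 128, 2))  # 30720, next multiple of 256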
...@@ -97,7 +97,6 @@ def pretrain(train_valid_test_dataset_provider, model_provider, ...@@ -97,7 +97,6 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
model, optimizer, lr_scheduler, model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator) train_data_iterator, valid_data_iterator)
if args.do_valid: if args.do_valid:
prefix = 'the end of training for val data' prefix = 'the end of training for val data'
evaluate_and_print_results(prefix, forward_step_func, evaluate_and_print_results(prefix, forward_step_func,
...@@ -174,7 +173,7 @@ def get_optimizer(model): ...@@ -174,7 +173,7 @@ def get_optimizer(model):
dynamic_loss_scale=args.dynamic_loss_scale, dynamic_loss_scale=args.dynamic_loss_scale,
dynamic_loss_args={ dynamic_loss_args={
'scale_window': args.loss_scale_window, 'scale_window': args.loss_scale_window,
'min_scale':args.min_scale, 'min_scale': args.min_scale,
'delayed_shift': args.hysteresis}) 'delayed_shift': args.hysteresis})
return optimizer return optimizer
...@@ -297,6 +296,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -297,6 +296,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
# Logging. # Logging.
timers_to_log = [] timers_to_log = []
def add_to_logging(name): def add_to_logging(name):
if name in timers.timers: if name in timers.timers:
timers_to_log.append(name) timers_to_log.append(name)
...@@ -431,7 +431,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False): ...@@ -431,7 +431,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
# Reduce across processes. # Reduce across processes.
for key in loss_dict: for key in loss_dict:
total_loss_dict[key] = total_loss_dict.get(key, 0.) + \ total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
loss_dict[key] loss_dict[key]
# Move model back to the train mode. # Move model back to the train mode.
model.train() model.train()
...@@ -521,14 +521,14 @@ def build_train_valid_test_data_iterators( ...@@ -521,14 +521,14 @@ def build_train_valid_test_data_iterators(
# Shift the start iterations. # Shift the start iterations.
if train_dataloader is not None: if train_dataloader is not None:
train_dataloader.batch_sampler.start_iter = args.iteration % \ train_dataloader.batch_sampler.start_iter = args.iteration % \
len(train_dataloader) len(train_dataloader)
print_rank_0('setting training data start iteration to {}'. print_rank_0('setting training data start iteration to {}'.
format(train_dataloader.batch_sampler.start_iter)) format(train_dataloader.batch_sampler.start_iter))
if valid_dataloader is not None: if valid_dataloader is not None:
start_iter_val = (args.iteration // args.eval_interval) * \ start_iter_val = (args.iteration // args.eval_interval) * \
args.eval_iters args.eval_iters
valid_dataloader.batch_sampler.start_iter = start_iter_val % \ valid_dataloader.batch_sampler.start_iter = start_iter_val % \
len(valid_dataloader) len(valid_dataloader)
print_rank_0('setting validation data start iteration to {}'. print_rank_0('setting validation data start iteration to {}'.
format(valid_dataloader.batch_sampler.start_iter)) format(valid_dataloader.batch_sampler.start_iter))
......
...@@ -48,7 +48,7 @@ def report_memory(name): ...@@ -48,7 +48,7 @@ def report_memory(name):
torch.cuda.max_memory_allocated() / mega_bytes) torch.cuda.max_memory_allocated() / mega_bytes)
string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes) string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
string += ' | max cached: {}'.format( string += ' | max cached: {}'.format(
torch.cuda.max_memory_cached()/ mega_bytes) torch.cuda.max_memory_cached() / mega_bytes)
print_rank_0(string) print_rank_0(string)
...@@ -164,10 +164,10 @@ def get_ltor_masks_and_position_ids(data, ...@@ -164,10 +164,10 @@ def get_ltor_masks_and_position_ids(data,
i = eod_index[j] i = eod_index[j]
# Mask attention loss. # Mask attention loss.
if reset_attention_mask: if reset_attention_mask:
attention_mask[b, 0, (i+1):, :(i+1)] = 0 attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
# Reset positions. # Reset positions.
if reset_position_ids: if reset_position_ids:
position_ids[b, (i+1):] -= (i + 1 - prev_index) position_ids[b, (i + 1):] -= (i + 1 - prev_index)
prev_index = i + 1 prev_index = i + 1
# Convert # Convert
......
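The `get_ltor_masks_and_position_ids` hunk above reformats the end-of-document reset logic; a minimal single-sequence sketch of what those two statements do, with both reset flags effectively on (the EOD id, token values, and dense mask layout are simplified assumptions):

import torch

seq = torch.tensor([5, 6, 7, 0, 8, 9])      # token 0 stands in for the EOD id
attention_mask = torch.tril(torch.ones(len(seq), len(seq)))
position_ids = torch.arange(len(seq))

prev_index = 0
for i in (seq == 0).nonzero(as_tuple=True)[0].tolist():
    attention_mask[(i + 1):, :(i + 1)] = 0          # no attention across documents
    position_ids[(i + 1):] -= (i + 1 - prev_index)  # restart positions after EOD
    prev_index = i + 1

print(position_ids.tolist())  # [0, 1, 2, 3, 0, 1]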
...@@ -75,8 +75,8 @@ def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length, ...@@ -75,8 +75,8 @@ def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length,
# A. # A.
len_text_a = len(text_a_ids) len_text_a = len(text_a_ids)
ids.extend(text_a_ids) ids.extend(text_a_ids)
types.extend([0]*len_text_a) types.extend([0] * len_text_a)
paddings.extend([1]*len_text_a) paddings.extend([1] * len_text_a)
# [SEP]. # [SEP].
ids.append(sep_id) ids.append(sep_id)
...@@ -87,8 +87,8 @@ def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length, ...@@ -87,8 +87,8 @@ def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length,
if text_b_ids is not None: if text_b_ids is not None:
len_text_b = len(text_b_ids) len_text_b = len(text_b_ids)
ids.extend(text_b_ids) ids.extend(text_b_ids)
types.extend([1]*len_text_b) types.extend([1] * len_text_b)
paddings.extend([1]*len_text_b) paddings.extend([1] * len_text_b)
# Cap the size. # Cap the size.
trimmed = False trimmed = False
...@@ -111,8 +111,8 @@ def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length, ...@@ -111,8 +111,8 @@ def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length,
# Padding. # Padding.
padding_length = max_seq_length - len(ids) padding_length = max_seq_length - len(ids)
if padding_length > 0: if padding_length > 0:
ids.extend([pad_id]*padding_length) ids.extend([pad_id] * padding_length)
types.extend([pad_id]*padding_length) types.extend([pad_id] * padding_length)
paddings.extend([0]*padding_length) paddings.extend([0] * padding_length)
return ids, types, paddings return ids, types, paddings
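A toy walk-through of the token/type/padding bookkeeping in the hunks above. Only the parts visible in this diff are reproduced; the [CLS] and trailing-[SEP] handling is omitted, and the type/padding values attached to the [SEP] token are an assumption since those lines are not shown:

sep_id, pad_id, max_seq_length = 102, 0, 10      # placeholder ids
text_a_ids, text_b_ids = [7, 8, 9], [11, 12]     # toy token ids

ids, types, paddings = [], [], []
# A.
ids.extend(text_a_ids)
types.extend([0] * len(text_a_ids))
paddings.extend([1] * len(text_a_ids))
# [SEP] (type/padding values assumed).
ids.append(sep_id)
types.append(0)
paddings.append(1)
# B.
ids.extend(text_b_ids)
types.extend([1] * len(text_b_ids))
paddings.extend([1] * len(text_b_ids))
# Padding.
padding_length = max_seq_length - len(ids)
ids.extend([pad_id] * padding_length)
types.extend([pad_id] * padding_length)
paddings.extend([0] * padding_length)

print(ids)       # [7, 8, 9, 102, 11, 12, 0, 0, 0, 0]
print(types)     # [0, 0, 0, 0, 1, 1, 0, 0, 0, 0]
print(paddings)  # [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]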
...@@ -5,6 +5,7 @@ import collections ...@@ -5,6 +5,7 @@ import collections
import numpy as np import numpy as np
import torch import torch
def process_files(args): def process_files(args):
all_predictions = collections.OrderedDict() all_predictions = collections.OrderedDict()
all_labels = collections.OrderedDict() all_labels = collections.OrderedDict()
...@@ -40,12 +41,12 @@ def get_threshold(all_predictions, all_labels, one_threshold=False): ...@@ -40,12 +41,12 @@ def get_threshold(all_predictions, all_labels, one_threshold=False):
for dataset in all_predictions: for dataset in all_predictions:
preds = all_predictions[dataset] preds = all_predictions[dataset]
labels = all_labels[dataset] labels = all_labels[dataset]
out_thresh.append(calc_threshold(preds,labels)) out_thresh.append(calc_threshold(preds, labels))
return out_thresh return out_thresh
def calc_threshold(p, l): def calc_threshold(p, l):
trials = [(i)*(1./100.) for i in range(100)] trials = [(i) * (1. / 100.) for i in range(100)]
best_acc = float('-inf') best_acc = float('-inf')
best_thresh = 0 best_thresh = 0
for t in trials: for t in trials:
...@@ -58,7 +59,7 @@ def calc_threshold(p, l): ...@@ -58,7 +59,7 @@ def calc_threshold(p, l):
def apply_threshold(preds, t): def apply_threshold(preds, t):
assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0]))) assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0])))
prob = preds[:,-1] prob = preds[:, -1]
thresholded = (prob >= t).astype(int) thresholded = (prob >= t).astype(int)
preds = np.zeros_like(preds) preds = np.zeros_like(preds)
preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1 preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
...@@ -66,8 +67,8 @@ def apply_threshold(preds, t): ...@@ -66,8 +67,8 @@ def apply_threshold(preds, t):
def threshold_predictions(all_predictions, threshold): def threshold_predictions(all_predictions, threshold):
if len(threshold)!=len(all_predictions): if len(threshold) != len(all_predictions):
threshold = [threshold[-1]]*(len(all_predictions)-len(threshold)) threshold = [threshold[-1]] * (len(all_predictions) - len(threshold))
for i, dataset in enumerate(all_predictions): for i, dataset in enumerate(all_predictions):
thresh = threshold[i] thresh = threshold[i]
preds = all_predictions[dataset] preds = all_predictions[dataset]
...@@ -77,7 +78,7 @@ def threshold_predictions(all_predictions, threshold): ...@@ -77,7 +78,7 @@ def threshold_predictions(all_predictions, threshold):
def postprocess_predictions(all_predictions, all_labels, args): def postprocess_predictions(all_predictions, all_labels, args):
for d in all_predictions: for d in all_predictions:
all_predictions[d] = all_predictions[d]/len(args.paths) all_predictions[d] = all_predictions[d] / len(args.paths)
if args.calc_threshold: if args.calc_threshold:
args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold) args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold)
...@@ -98,19 +99,22 @@ def write_predictions(all_predictions, all_labels, all_uid, args): ...@@ -98,19 +99,22 @@ def write_predictions(all_predictions, all_labels, all_uid, args):
if args.eval: if args.eval:
correct = (preds == all_labels[dataset]).sum() correct = (preds == all_labels[dataset]).sum()
num = len(all_labels[dataset]) num = len(all_labels[dataset])
accuracy = correct/num accuracy = correct / num
count += num count += num
all_correct += correct all_correct += correct
accuracy = (preds == all_labels[dataset]).mean() accuracy = (preds == all_labels[dataset]).mean()
print(accuracy) print(accuracy)
if not os.path.exists(os.path.join(args.outdir, dataset)): if not os.path.exists(os.path.join(args.outdir, dataset)):
os.makedirs(os.path.join(args.outdir, dataset)) os.makedirs(os.path.join(args.outdir, dataset))
outpath = os.path.join(args.outdir, dataset, os.path.splitext(args.prediction_name)[0]+'.tsv') outpath = os.path.join(
args.outdir, dataset, os.path.splitext(
args.prediction_name)[0] + '.tsv')
with open(outpath, 'w') as f: with open(outpath, 'w') as f:
f.write('id\tlabel\n') f.write('id\tlabel\n')
f.write('\n'.join(str(uid)+'\t'+str(args.labels[p]) for uid, p in zip(all_uid[dataset], preds.tolist()))) f.write('\n'.join(str(uid) + '\t' + str(args.labels[p])
for uid, p in zip(all_uid[dataset], preds.tolist())))
if args.eval: if args.eval:
print(all_correct/count) print(all_correct / count)
def ensemble_predictions(args): def ensemble_predictions(args):
...@@ -119,7 +123,7 @@ def ensemble_predictions(args): ...@@ -119,7 +123,7 @@ def ensemble_predictions(args):
write_predictions(all_predictions, all_labels, all_uid, args) write_predictions(all_predictions, all_labels, all_uid, args)
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--paths', required=True, nargs='+', parser.add_argument('--paths', required=True, nargs='+',
help='paths to checkpoint directories used in ensemble') help='paths to checkpoint directories used in ensemble')
...@@ -135,11 +139,11 @@ def main(): ...@@ -135,11 +139,11 @@ def main():
help='use on threshold for all subdatasets') help='use on threshold for all subdatasets')
parser.add_argument('--threshold', nargs='+', default=None, type=float, parser.add_argument('--threshold', nargs='+', default=None, type=float,
help='user supplied threshold for classification') help='user supplied threshold for classification')
parser.add_argument('--labels',nargs='+', default=None, parser.add_argument('--labels', nargs='+', default=None,
help='whitespace separated list of label names') help='whitespace separated list of label names')
args = parser.parse_args() args = parser.parse_args()
ensemble_predictions(args) ensemble_predictions(args)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
\ No newline at end of file
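For context on the thresholding hunks above: the ensemble script averages each dataset's predictions over the checkpoint paths and then applies a per-dataset threshold to turn them into hard labels. A minimal sketch of that step for the two-class case follows; the function name and the (N, 2) probability layout are my assumptions, not taken from the diff.

import numpy as np

def apply_threshold_sketch(probs, t):
    # Turn averaged two-class probabilities of shape (N, 2) into one-hot
    # predictions; class 1 is chosen only when its probability clears t.
    chosen = (probs[:, 1] > t).astype(int)
    preds = np.zeros_like(probs)
    preds[np.arange(len(chosen)), chosen] = 1   # same scatter pattern as the hunk above
    return preds

# Example: apply_threshold_sketch(np.array([[0.6, 0.4], [0.2, 0.8]]), 0.5)
# returns [[1., 0.], [0., 1.]].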
...@@ -21,7 +21,7 @@ from megatron import get_args ...@@ -21,7 +21,7 @@ from megatron import get_args
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron import print_rank_0 from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer from megatron.training import setup_model_and_optimizer
...@@ -53,7 +53,7 @@ def _cross_entropy_forward_step(batch, model): ...@@ -53,7 +53,7 @@ def _cross_entropy_forward_step(batch, model):
timers('batch generator').start() timers('batch generator').start()
try: try:
batch_ = next(batch) batch_ = next(batch)
except: except BaseException:
batch_ = batch batch_ = batch
tokens, types, labels, attention_mask = process_batch(batch_) tokens, types, labels, attention_mask = process_batch(batch_)
timers('batch generator').stop() timers('batch generator').stop()
...@@ -146,7 +146,7 @@ def _train(model, optimizer, lr_scheduler, forward_step, ...@@ -146,7 +146,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
# For each remaining epoch # For each remaining epoch
timers('interval time').start() timers('interval time').start()
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):
print_rank_0('working on epoch {} ...'.format(epoch+1)) print_rank_0('working on epoch {} ...'.format(epoch + 1))
# Set the data loader epoch to shuffle the index iterator. # Set the data loader epoch to shuffle the index iterator.
train_dataloader.sampler.set_epoch(args.seed + epoch) train_dataloader.sampler.set_epoch(args.seed + epoch)
...@@ -172,7 +172,7 @@ def _train(model, optimizer, lr_scheduler, forward_step, ...@@ -172,7 +172,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
report_memory_flag) report_memory_flag)
# Autoresume # Autoresume
if args.adlr_autoresume and \ if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0): (iteration % args.adlr_autoresume_interval == 0):
check_adlr_autoresume_termination(iteration, model, check_adlr_autoresume_termination(iteration, model,
optimizer, lr_scheduler) optimizer, lr_scheduler)
......
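The except BaseException change above is the linter's mechanical replacement for a bare except:; the underlying intent of the pattern is "accept either an iterator over batches or an already-materialized batch". A hedged sketch of that intent, written here with a narrower TypeError check rather than the repo's catch-all:

def next_batch(batch):
    # Accept either an iterator over batches or an already-materialized batch;
    # TypeError is what next() raises when `batch` is not an iterator.
    try:
        return next(batch)
    except TypeError:
        return batch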
...@@ -48,11 +48,9 @@ class GLUEAbstractDataset(ABC, Dataset): ...@@ -48,11 +48,9 @@ class GLUEAbstractDataset(ABC, Dataset):
print_rank_0(' >> total number of samples: {}'.format( print_rank_0(' >> total number of samples: {}'.format(
len(self.samples))) len(self.samples)))
def __len__(self): def __len__(self):
return len(self.samples) return len(self.samples)
def __getitem__(self, idx): def __getitem__(self, idx):
raw_sample = self.samples[idx] raw_sample = self.samples[idx]
ids, types, paddings = build_tokens_types_paddings_from_text( ids, types, paddings = build_tokens_types_paddings_from_text(
...@@ -62,7 +60,6 @@ class GLUEAbstractDataset(ABC, Dataset): ...@@ -62,7 +60,6 @@ class GLUEAbstractDataset(ABC, Dataset):
raw_sample['label'], raw_sample['uid']) raw_sample['label'], raw_sample['uid'])
return sample return sample
@abstractmethod @abstractmethod
def process_samples_from_single_path(self, datapath): def process_samples_from_single_path(self, datapath):
"""Abstract method that takes a single path / filename and """Abstract method that takes a single path / filename and
......
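The hunk above only tightens blank lines, but it shows the shape of the GLUE base class: a torch Dataset that stores raw samples and defers file parsing to an abstract process_samples_from_single_path. A minimal, self-contained sketch of that pattern; the class name and field layout here are illustrative, not the repo's exact definitions.

from abc import ABC, abstractmethod
from torch.utils.data import Dataset

class TSVTaskDataset(ABC, Dataset):
    # Sketch of the abstract-dataset pattern: subclasses only parse files.
    def __init__(self, datapaths):
        self.samples = []
        for path in datapaths:
            self.samples.extend(self.process_samples_from_single_path(path))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

    @abstractmethod
    def process_samples_from_single_path(self, datapath):
        """Return a list of sample dicts parsed from one file."""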
...@@ -38,7 +38,6 @@ def glue_classification(num_classes, Dataset, ...@@ -38,7 +38,6 @@ def glue_classification(num_classes, Dataset,
return train_dataset, valid_dataset return train_dataset, valid_dataset
def model_provider(): def model_provider():
"""Build the model.""" """Build the model."""
args = get_args() args = get_args()
...@@ -48,7 +47,6 @@ def glue_classification(num_classes, Dataset, ...@@ -48,7 +47,6 @@ def glue_classification(num_classes, Dataset,
return Classification(num_classes=num_classes, num_tokentypes=2) return Classification(num_classes=num_classes, num_tokentypes=2)
def metrics_func_provider(): def metrics_func_provider():
"""Privde metrics callback function.""" """Privde metrics callback function."""
def single_dataset_provider(datapath): def single_dataset_provider(datapath):
...@@ -59,7 +57,6 @@ def glue_classification(num_classes, Dataset, ...@@ -59,7 +57,6 @@ def glue_classification(num_classes, Dataset,
return Dataset(name, [datapath], tokenizer, args.seq_length) return Dataset(name, [datapath], tokenizer, args.seq_length)
return accuracy_func_provider(single_dataset_provider) return accuracy_func_provider(single_dataset_provider)
"""Finetune/evaluate.""" """Finetune/evaluate."""
finetune(train_valid_datasets_provider, model_provider, finetune(train_valid_datasets_provider, model_provider,
end_of_epoch_callback_provider=metrics_func_provider) end_of_epoch_callback_provider=metrics_func_provider)
...@@ -72,6 +69,7 @@ def main(): ...@@ -72,6 +69,7 @@ def main():
num_classes = 3 num_classes = 3
from tasks.glue.mnli import MNLIDataset as Dataset from tasks.glue.mnli import MNLIDataset as Dataset
def name_from_datapath(datapath): def name_from_datapath(datapath):
return datapath.split('MNLI')[-1].strip( return datapath.split('MNLI')[-1].strip(
'.tsv').strip('/').replace('_', '-') '.tsv').strip('/').replace('_', '-')
...@@ -80,6 +78,7 @@ def main(): ...@@ -80,6 +78,7 @@ def main():
num_classes = 2 num_classes = 2
from tasks.glue.qqp import QQPDataset as Dataset from tasks.glue.qqp import QQPDataset as Dataset
def name_from_datapath(datapath): def name_from_datapath(datapath):
return datapath.split('QQP')[-1].strip( return datapath.split('QQP')[-1].strip(
'.tsv').strip('/').replace('_', '-') '.tsv').strip('/').replace('_', '-')
......
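For the name_from_datapath helpers above, the intent is to turn a GLUE file path into a short run name. A worked example under an assumed directory layout (the path below is illustrative, not from the repo):

datapath = '/data/MNLI/dev_matched.tsv'   # assumed layout
name = datapath.split('MNLI')[-1].strip('.tsv').strip('/').replace('_', '-')
# '/dev_matched.tsv' -> strip('.tsv') -> '/dev_matched' -> strip('/') -> 'dev_matched' -> 'dev-matched'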
...@@ -31,7 +31,6 @@ class MNLIDataset(GLUEAbstractDataset): ...@@ -31,7 +31,6 @@ class MNLIDataset(GLUEAbstractDataset):
super().__init__('MNLI', name, datapaths, super().__init__('MNLI', name, datapaths,
tokenizer, max_seq_length) tokenizer, max_seq_length)
def process_samples_from_single_path(self, filename): def process_samples_from_single_path(self, filename):
""""Implement abstract method.""" """"Implement abstract method."""
print_rank_0(' > Processing {} ...'.format(filename)) print_rank_0(' > Processing {} ...'.format(filename))
......
...@@ -31,7 +31,6 @@ class QQPDataset(GLUEAbstractDataset): ...@@ -31,7 +31,6 @@ class QQPDataset(GLUEAbstractDataset):
super().__init__('QQP', name, datapaths, super().__init__('QQP', name, datapaths,
tokenizer, max_seq_length) tokenizer, max_seq_length)
def process_samples_from_single_path(self, filename): def process_samples_from_single_path(self, filename):
""""Implement abstract method.""" """"Implement abstract method."""
print_rank_0(' > Processing {} ...'.format(filename)) print_rank_0(' > Processing {} ...'.format(filename))
......
...@@ -15,14 +15,13 @@ ...@@ -15,14 +15,13 @@
"""Main tasks functionality.""" """Main tasks functionality."""
from megatron.initialize import initialize_megatron
from megatron import get_args
import os import os
import sys import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir))) os.path.pardir)))
from megatron import get_args
from megatron.initialize import initialize_megatron
def get_tasks_args(parser): def get_tasks_args(parser):
"""Provide extra arguments required for tasks.""" """Provide extra arguments required for tasks."""
...@@ -46,7 +45,7 @@ def get_tasks_args(parser): ...@@ -46,7 +45,7 @@ def get_tasks_args(parser):
group.add_argument('--overlapping-eval', type=int, default=32, group.add_argument('--overlapping-eval', type=int, default=32,
help='Sliding window for overlapping evaluation.') help='Sliding window for overlapping evaluation.')
group.add_argument('--strict-lambada', action='store_true', group.add_argument('--strict-lambada', action='store_true',
help='Use more difficult formulation of lambada.') help='Use more difficult formulation of lambada.')
return parser return parser
......
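The last hunk shows the usual pattern of a task-specific argument hook that receives the shared Megatron parser and adds its own option group. A minimal standalone sketch of that hook (it builds its own parser here instead of receiving Megatron's, and the function name is mine):

import argparse

def get_tasks_args_sketch(parser):
    # Add task-specific flags to an existing parser and return it.
    group = parser.add_argument_group(title='tasks')
    group.add_argument('--overlapping-eval', type=int, default=32,
                       help='Sliding window for overlapping evaluation.')
    group.add_argument('--strict-lambada', action='store_true',
                       help='Use more difficult formulation of lambada.')
    return parser

if __name__ == '__main__':
    args = get_tasks_args_sketch(argparse.ArgumentParser()).parse_args([])
    print(args.overlapping_eval)   # -> 32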