".github/vscode:/vscode.git/clone" did not exist on "9b7cf9ee6c299e85e3273842ee2b007312f9276d"
Commit 5e56e563 authored by Neel Kant

Merge master into realm-mlm

parents 6c0a5bd8 569b3dab
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -72,7 +72,6 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=get_model_parallel_group())

         # Loss = log(sum(exp(logits))) - predicted-logit.
         loss = torch.log(sum_exp_logits) - predicted_logits
...
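Note: the loss line above relies on the identity -log softmax(x)[t] = log(sum(exp(x))) - x[t]. A minimal single-GPU sketch that checks it with plain PyTorch (no vocabulary partitioning, and omitting the max-subtraction the class uses for numerical stability):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10)            # [batch, vocab]
    target = torch.randint(0, 10, (4,))    # [batch]

    # log(sum(exp(logits))) - predicted_logit, computed stably.
    loss = torch.logsumexp(logits, dim=-1) - logits.gather(
        1, target.unsqueeze(1)).squeeze(1)

    # Matches the built-in (unreduced) cross entropy.
    reference = F.cross_entropy(logits, target, reduction='none')
    assert torch.allclose(loss, reference, atol=1e-5)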
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,10 +21,52 @@
 import torch
 from torch._six import inf

+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C
+
 from .initialize import get_model_parallel_group
 from .initialize import get_model_parallel_rank
+
+def l2_grad_clipper(parameters, max_norm):
+    """Efficient L2 norm gradient clipping."""
+    overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
+    # Make sure we have an iterable.
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    # Filter parameters with gradients.
+    parameters_with_grads = list(filter(
+        lambda p: p.grad is not None, parameters))
+    # Filter parameters for the norm calculation: model-parallel parameters
+    # count on every rank, while replicated parameters count only on rank 0
+    # so they are not double-counted in the all-reduce below.
+    mp_rank_is_zero = (get_model_parallel_rank() == 0)
+    parameters_for_norm = list(filter(
+        lambda p: p.model_parallel or mp_rank_is_zero, parameters_with_grads))
+    # Calculate the local L2 norm of the gradients (not the weights).
+    norm, _ = multi_tensor_applier(
+        amp_C.multi_tensor_l2norm,
+        overflow_buf,
+        [[p.grad for p in parameters_for_norm]],
+        False)  # no per-parameter norm
+    # Sum the squared local norms across all model parallel GPUs.
+    norm_2 = norm * norm
+    torch.distributed.all_reduce(norm_2,
+                                 op=torch.distributed.ReduceOp.SUM,
+                                 group=get_model_parallel_group())
+    total_norm = norm_2.item() ** 0.5
+    # Scale gradients in place if their norm exceeds max_norm.
+    clip_coef = float(max_norm) / (total_norm + 1.0e-6)
+    grads = [p.grad for p in parameters_with_grads]
+    if clip_coef < 1.0:
+        multi_tensor_applier(
+            amp_C.multi_tensor_scale,
+            overflow_buf,
+            [grads, grads],
+            clip_coef)
+    return total_norm
+
 def clip_grad_norm(parameters, max_norm, norm_type=2):
     """Clips gradient norm of an iterable of parameters.
@@ -55,6 +97,13 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=get_model_parallel_group())
         total_norm = total_norm_cuda[0].item()
+        clip_coef = max_norm / (total_norm + 1e-6)
+        if clip_coef < 1:
+            for p in parameters:
+                p.grad.data.mul_(clip_coef)
+    # elif norm_type == 2:
+    #     total_norm = l2_grad_clipper(parameters, max_norm)
     else:
         total_norm = 0
         for p in parameters:
...
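Note: with the L2 fast path left commented out, every non-inf norm type falls through to the per-parameter loop in the else branch. A hypothetical call site, sketching where clip_grad_norm sits in a training step (model, optimizer, and distributed/mpu initialization are assumed, not shown in this diff):

    # Hypothetical training-step excerpt; the names are assumptions.
    loss.backward()
    grad_norm = clip_grad_norm(model.parameters(), max_norm=1.0, norm_type=2)
    optimizer.step()
    optimizer.zero_grad()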
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -89,6 +89,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         embedding_dim: size of hidden state.
         init_method: method to initialize weights.
     """
+
     def __init__(self, num_embeddings, embedding_dim,
                  init_method=init.xavier_normal_):
         super(VocabParallelEmbedding, self).__init__()
@@ -147,6 +148,7 @@ class ParallelEmbedding(torch.nn.Module):
         embedding_dim: size of hidden state.
         init_method: method to initialize weights.
     """
+
     def __init__(self, num_embeddings, embedding_dim,
                  init_method=init.xavier_normal_,
                  keep_master_weight_for_test=False):
@@ -205,6 +207,7 @@ class ColumnParallelLinear(torch.nn.Module):
                                     set to False. It returns the master weights
                                     used for initialization.
     """
+
     def __init__(self, input_size, output_size, bias=True, gather_output=True,
                  init_method=init.xavier_normal_, stride=1,
                  keep_master_weight_for_test=False):
@@ -279,6 +282,7 @@ class RowParallelLinear(torch.nn.Module):
                                     set to False. It returns the master weights
                                     used for initialization.
     """
+
     def __init__(self, input_size, output_size, bias=True,
                  input_is_parallel=False,
                  init_method=init.xavier_normal_, stride=1,
@@ -327,4 +331,3 @@ class RowParallelLinear(torch.nn.Module):
         else:
             output = output_
         return output
-
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -131,11 +131,14 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
 def copy_to_model_parallel_region(input_):
     return _CopyToModelParallelRegion.apply(input_)

+
 def reduce_from_model_parallel_region(input_):
     return _ReduceFromModelParallelRegion.apply(input_)

+
 def scatter_to_model_parallel_region(input_):
     return _ScatterToModelParallelRegion.apply(input_)

+
 def gather_from_model_parallel_region(input_):
     return _GatherFromModelParallelRegion.apply(input_)
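Note: these wrappers pair up as conjugates: copy_to is an identity in forward with an all-reduce in backward, reduce_from is the reverse, and scatter/gather split and concatenate along the last dimension. A rough sketch (an assumed composition, not code from this diff) of a column-parallel linear forward built from them:

    # weight_shard is this rank's slice of the output columns.
    def column_parallel_forward(input_, weight_shard, gather_output=True):
        input_parallel = copy_to_model_parallel_region(input_)
        output_parallel = input_parallel @ weight_shard.t()
        if gather_output:
            # Concatenate the per-rank column slices along the last dimension.
            output_parallel = gather_from_model_parallel_region(output_parallel)
        return output_parallel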
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -73,6 +73,7 @@ class CudaRNGStatesTracker:
     rng state, we can perform operations and return to our starting
     cuda state.
     """
+
     def __init__(self):
         # Map from a string name to the cuda rng state.
         self.states_ = {}
...
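Note: the docstring describes a fork-and-restore pattern. The same idea can be sketched standalone with stock PyTorch calls (an illustration of the concept, not the tracker's own API):

    import contextlib
    import torch

    @contextlib.contextmanager
    def fork_cuda_rng_state(state):
        # Swap in a saved CUDA RNG state, restore the original on exit.
        original = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(state)
        try:
            yield
        finally:
            torch.cuda.set_rng_state(original)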
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ class IdentityLayer(torch.nn.Module):
     def __init__(self, size, scale=1.0):
         super(IdentityLayer, self).__init__()
         self.weight = torch.nn.Parameter(scale * torch.randn(size))
+
     def forward(self):
         return self.weight
...
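Note: IdentityLayer is the tests' way of getting a differentiable "input": its forward simply returns a trainable weight, so gradient flow can be asserted end to end. Assumed usage:

    layer = IdentityLayer(size=(4, 8))    # the weight doubles as the test input
    loss = layer().sum()
    loss.backward()
    assert layer.weight.grad is not None  # gradients reach the "input"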
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,20 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from commons import set_random_seed
-from commons import IdentityLayer
-from commons import print_separator
-from commons import initialize_distributed
-from mpu.cross_entropy import vocab_parallel_cross_entropy
-import mpu
-import torch.nn.functional as F
-import torch
 import random
 import sys
 sys.path.append("../..")
+import torch
+import torch.nn.functional as F
+import mpu
+from mpu.cross_entropy import vocab_parallel_cross_entropy
+from commons import initialize_distributed
+from commons import print_separator
+from commons import IdentityLayer
+from commons import set_random_seed

 def torch_cross_entropy(batch_size, seq_length, vocab_size,
                         logits_scale, seed):
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,18 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from commons import print_separator
-from commons import initialize_distributed
-from mpu import data as data_utils
-import mpu
-import torch
 import functools
 import operator
 import sys
 sys.path.append("../..")
+import torch
+import mpu
+from mpu import data as data_utils
+from commons import initialize_distributed
+from commons import print_separator

 def test_boradcast_data(model_parallel_size):
@@ -88,5 +86,3 @@ if __name__ == '__main__':
         print_separator('test test boradcast data')
         test_boradcast_data(model_parallel_size)
         model_parallel_size *= 2
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,15 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from commons import print_separator
-from commons import initialize_distributed
-import mpu
-import torch
 import sys
 sys.path.append("../..")
+import torch
+import mpu
+from commons import initialize_distributed
+from commons import print_separator

 def test_initialize_model_parallel(model_parallel_size):
@@ -46,7 +44,6 @@ def test_initialize_model_parallel(model_parallel_size):
     assert rank == mpu.get_model_parallel_rank()
     check(mpu.get_model_parallel_group(), world_size, rank)

     # Data parallel.
     world_size = torch.distributed.get_world_size() // model_parallel_size_
     rank = torch.distributed.get_rank() // model_parallel_size
...
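Note: to make the rank arithmetic in the hunk above concrete, with 8 GPUs and model_parallel_size = 2 the data-parallel world size is 4 and consecutive rank pairs share a model-parallel group. A standalone check (sizes assumed):

    world_size_total = 8      # assumed torch.distributed.get_world_size()
    model_parallel_size = 2
    dp_world_size = world_size_total // model_parallel_size        # 4
    dp_ranks = [r // model_parallel_size for r in range(world_size_total)]
    print(dp_world_size, dp_ranks)  # 4 [0, 0, 1, 1, 2, 2, 3, 3]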
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,20 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from mpu import layers
-from commons import set_random_seed
-from commons import print_separator
-from commons import initialize_distributed
-import mpu
-from torch.nn.parameter import Parameter
-import torch.nn.init as init
-import torch
 import random
 import sys
 sys.path.append("../..")
+import torch
+import torch.nn.init as init
+from torch.nn.parameter import Parameter
+import mpu
+from commons import initialize_distributed
+from commons import print_separator
+from commons import set_random_seed
+from mpu import layers

 def test_parallel_embedding(model_parallel_size):
@@ -45,7 +43,7 @@ def test_parallel_embedding(model_parallel_size):
     set_random_seed(123)
     input_data = torch.LongTensor(
-        size=(batch_size,seq_length)).random_(0, vocab_size).cuda()
+        size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
     loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

     set_random_seed(seed)
@@ -176,10 +174,11 @@ def test_initialize_affine_weight(model_parallel_size):

 class IdentityLayer2D(torch.nn.Module):
-    def __init__(self, m , n):
+    def __init__(self, m, n):
         super(IdentityLayer2D, self).__init__()
         self.weight = Parameter(torch.Tensor(m, n))
         torch.nn.init.xavier_normal_(self.weight)
+
     def forward(self):
         return self.weight
@@ -317,10 +316,11 @@ def test_row_parallel_linear(model_parallel_size):

 class IdentityLayer3D(torch.nn.Module):
-    def __init__(self, m , n, k):
+    def __init__(self, m, n, k):
         super(IdentityLayer3D, self).__init__()
         self.weight = Parameter(torch.Tensor(m, n, k))
         torch.nn.init.xavier_normal_(self.weight)
+
     def forward(self):
         return self.weight
@@ -371,12 +371,12 @@ def test_parallel_self_attention(model_parallel_size):
     sequence_length = 13

     rank_1, hideen_size_1, model_parallel_size_1, loss_1, \
-        attention_layer_1, identity_layer_1 =parallel_self_attention(
+        attention_layer_1, identity_layer_1 = parallel_self_attention(
             1, num_att_heads_per_partition,
             hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)

     rank, hidden_size, model_parallel_size, loss, \
-        attention_layer, identity_layer =parallel_self_attention(
+        attention_layer, identity_layer = parallel_self_attention(
             model_parallel_size, num_att_heads_per_partition,
             hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)

     assert hideen_size_1 == hidden_size
@@ -409,6 +409,7 @@ def test_parallel_self_attention(model_parallel_size):
     if torch.distributed.get_rank() == 0:
         print(' >> passed the test :-)')

+
 def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
                          hidden_size_per_att_head, batch_size, sequence_length):
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,15 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from commons import print_separator
-from commons import initialize_distributed
-import mpu
-import torch
 import sys
 sys.path.append("../..")
+import torch
+import mpu
+from commons import initialize_distributed
+from commons import print_separator

 def test_set_cuda_rng_state(model_parallel_size):
@@ -204,4 +202,3 @@ if __name__ == '__main__':
         print_separator('test model parallel cuda manual seed')
         test_model_parallel_cuda_manual_seed(model_parallel_size)
         model_parallel_size *= 2
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -42,8 +42,7 @@ def get_batch(context_tokens):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)

     return tokens, attention_mask, position_ids
@@ -120,7 +119,7 @@ def generate_samples_input_from_file(model):
             context_length = len(context_tokens)

             if context_length >= (args.seq_length // 2):
-                print("\nContext length", context_length, \
+                print("\nContext length", context_length,
                       "\nPlease give smaller context (half of the "
                       "sequence length)!", flush=True)
                 continue
@@ -187,7 +186,7 @@ def generate_samples_interactive(model, print_frequency=24):
             context_length = len(context_tokens)

             if context_length >= (args.seq_length // 2):
-                print("\nContext length", context_length, \
+                print("\nContext length", context_length,
                       "\nPlease give smaller context (half of the "
                       "sequence length)!", flush=True)
                 continue
@@ -246,7 +245,7 @@ def generate_samples_unconditional(model):
         for token_stream in get_token_stream(model,
                                              copy.deepcopy(context_tokens)):
             pass
-        if ctr%args.log_interval == 0:
+        if ctr % args.log_interval == 0:
             print('Avg s/batch:',
                   (time.time() - start_time) / min(args.log_interval, ctr + 1))
             start_time = time.time()
@@ -254,10 +253,10 @@ def generate_samples_unconditional(model):
         token_batch = token_stream[0].cpu().numpy().tolist()
         length_batch = token_stream[1].cpu().numpy().tolist()
         for tokens, length in zip(token_batch, length_batch):
-            tokens = tokens[1:length-1]
+            tokens = tokens[1:length - 1]
             text = tokenizer.detokenize(tokens)
             is_finished = length < args.seq_length - 1
-            datum = {'text': text, 'length': length-1, 'finished': is_finished}
+            datum = {'text': text, 'length': length - 1, 'finished': is_finished}
             yield datum
             ctr += 1
             if ctr >= num_samples:
@@ -272,7 +271,7 @@ def generate_and_write_samples_unconditional(model):
     assert args.genfile is not None
     with open(args.genfile, 'w') as f:
         for datum in generate_samples_unconditional(model):
-            f.write(json.dumps(datum)+'\n')
+            f.write(json.dumps(datum) + '\n')

 def pad_batch(batch, pad_id, args):
@@ -281,7 +280,7 @@ def pad_batch(batch, pad_id, args):
     for tokens in batch:
         context_length = len(tokens)
         if context_length < args.seq_length:
-            tokens.extend([pad_id]*(args.seq_length - context_length))
+            tokens.extend([pad_id] * (args.seq_length - context_length))
         context_lengths.append(context_length)
     return batch, context_lengths
@@ -345,7 +344,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
     if maxlen > (org_context_length + args.out_seq_length):
         maxlen = org_context_length + args.out_seq_length

-    lengths = torch.ones([batch_size]).long().cuda()*maxlen
+    lengths = torch.ones([batch_size]).long().cuda() * maxlen

     while context_length <= (maxlen):
...
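Note: pad_batch above right-pads every sequence to args.seq_length in place and records the original lengths. A quick illustration with a stand-in args namespace (values assumed):

    from types import SimpleNamespace

    args = SimpleNamespace(seq_length=8)   # stand-in for Megatron args
    batch = [[5, 6, 7], [1, 2, 3, 4, 5]]
    padded, context_lengths = pad_batch(batch, pad_id=0, args=args)
    print(padded)           # [[5, 6, 7, 0, 0, 0, 0, 0], [1, 2, 3, 4, 5, 0, 0, 0]]
    print(context_lengths)  # [3, 5]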
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 """Tokenization classes."""

 from __future__ import absolute_import
...
@@ -29,7 +29,8 @@ try:
     from functools import lru_cache
 except ImportError:
     # Just a dummy decorator to get the checks to run on python2
-    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
+    # because honestly I don't want to support a byte-level unicode BPE
+    # tokenizer on python 2 right now.
     def lru_cache():
         return lambda func: func
@@ -49,6 +50,7 @@ VOCAB_NAME = 'vocab.json'
 MERGES_NAME = 'merges.txt'
 SPECIAL_TOKENS_NAME = 'special_tokens.txt'

+
 @lru_cache()
 def bytes_to_unicode():
     """
@@ -61,17 +63,19 @@ def bytes_to_unicode():
     And avoids mapping to whitespace/control characters the bpe code barfs on.
     """
     _chr = unichr if sys.version_info[0] == 2 else chr
-    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
+        list(range(ord("®"), ord("ÿ") + 1))
     cs = bs[:]
     n = 0
     for b in range(2**8):
         if b not in bs:
             bs.append(b)
-            cs.append(2**8+n)
+            cs.append(2**8 + n)
             n += 1
     cs = [_chr(n) for n in cs]
     return dict(zip(bs, cs))
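Note: the function gives every byte a printable unicode stand-in so arbitrary byte strings round-trip through the BPE vocabulary. A Python-3-only rendering that can be run directly (dropping the python2 shim):

    def bytes_to_unicode_py3():
        bs = list(range(ord("!"), ord("~") + 1)) + \
            list(range(ord("¡"), ord("¬") + 1)) + \
            list(range(ord("®"), ord("ÿ") + 1))
        cs = bs[:]
        n = 0
        for b in range(2**8):
            if b not in bs:
                bs.append(b)
                cs.append(2**8 + n)
                n += 1
        return dict(zip(bs, [chr(n) for n in cs]))

    mapping = bytes_to_unicode_py3()
    assert len(mapping) == 256          # every byte is covered
    assert mapping[ord('A')] == 'A'     # printable bytes map to themselves
    assert mapping[ord(' ')] == 'Ġ'     # space gets a visible stand-in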
 def get_pairs(word):
     """Return set of symbol pairs in a word.
@@ -84,6 +88,7 @@ def get_pairs(word):
         prev_char = char
     return pairs

+
 class GPT2Tokenizer(object):
     """
     GPT-2 BPE tokenizer. Peculiarities:
@@ -140,23 +145,31 @@ class GPT2Tokenizer(object):
             special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
         else:
             special_tokens = kwargs.pop('special_tokens', [])
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
+        tokenizer = cls(
+            resolved_vocab_file,
+            resolved_merges_file,
+            special_tokens=special_tokens,
+            *inputs,
+            **kwargs)
         return tokenizer

-    def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
+    def __init__(self, vocab_file, merges_file, errors='replace',
+                 special_tokens=None, max_len=None):
         self.max_len = max_len if max_len is not None else int(1e12)
         self.encoder = json.load(open(vocab_file))
-        self.decoder = {v:k for k,v in self.encoder.items()}
+        self.decoder = {v: k for k, v in self.encoder.items()}
         self.errors = errors  # how to handle errors in decoding
         self.byte_encoder = bytes_to_unicode()
-        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
         bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
         bpe_merges = [tuple(merge.split()) for merge in bpe_data]
         self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
         self.cache = {}
-        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
-        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+        # Should haved added re.IGNORECASE so BPE merges can happen for
+        # capitalized versions of contractions
+        self.pat = re.compile(
+            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

         self.special_tokens = {}
         self.special_tokens_decoder = {}
@@ -174,8 +187,9 @@ class GPT2Tokenizer(object):
             self.special_tokens = {}
             self.special_tokens_decoder = {}
             return
-        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
-        self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
+        self.special_tokens = dict((tok, len(self.encoder) + i)
+                                   for i, tok in enumerate(special_tokens))
+        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
         logger.info("Special tokens {}".format(self.special_tokens))

     def bpe(self, token):
@@ -188,7 +202,7 @@ class GPT2Tokenizer(object):
             return token

         while True:
-            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
             if bigram not in self.bpe_ranks:
                 break
             first, second = bigram
@@ -199,12 +213,12 @@ class GPT2Tokenizer(object):
                 j = word.index(first, i)
                 new_word.extend(word[i:j])
                 i = j
-            except:
+            except BaseException:
                 new_word.extend(word[i:])
                 break

-            if word[i] == first and i < len(word)-1 and word[i+1] == second:
-                new_word.append(first+second)
+            if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                new_word.append(first + second)
                 i += 2
             else:
                 new_word.append(word[i])
@@ -247,7 +261,8 @@ class GPT2Tokenizer(object):
             logger.warning(
                 "Token indices sequence length is longer than the specified maximum "
                 " sequence length for this OpenAI GPT model ({} > {}). Running this"
-                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
+                " sequence through the model will result in indexing errors".format(
+                    len(ids), self.max_len)
             )
         return ids
...
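Note: the while-loop in bpe greedily applies the lowest-ranked known merge until none remains. A toy, self-contained version of that loop with a hypothetical three-rule merge table:

    def toy_bpe(word, bpe_ranks):
        # Miniature of the merge loop in the hunk above.
        word = list(word)
        while len(word) > 1:
            pairs = [(word[i], word[i + 1]) for i in range(len(word) - 1)]
            bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float('inf')))
            if bigram not in bpe_ranks:
                break
            first, second = bigram
            merged, i = [], 0
            while i < len(word):
                if i < len(word) - 1 and (word[i], word[i + 1]) == (first, second):
                    merged.append(first + second)
                    i += 2
                else:
                    merged.append(word[i])
                    i += 1
            word = merged
        return word

    ranks = {('l', 'o'): 0, ('lo', 'w'): 1, ('e', 'r'): 2}  # hypothetical merges
    print(toy_bpe('lower', ranks))  # ['low', 'er']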
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -33,6 +33,9 @@ def build_tokenizer(args):
     if args.tokenizer_type == 'BertWordPieceLowerCase':
         tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
                                             lower_case=True)
+    elif args.tokenizer_type == 'BertWordPieceCase':
+        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
+                                            lower_case=False)
     elif args.tokenizer_type == 'GPT2BPETokenizer':
         assert args.merge_file is not None
         tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
@@ -168,6 +171,7 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
     def mask(self):
         return self.mask_id

+
 class _GPT2BPETokenizer(AbstractTokenizer):
     """Original GPT2 BPE tokenizer."""
...
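Note: the new BertWordPieceCase branch mirrors the lower-case one with lower_case=False. A hypothetical selection call (a sketch only: the real build_tokenizer also reads further fields from args, such as rank and vocab-padding settings, omitted here):

    from types import SimpleNamespace

    # Hypothetical args; the real ones come from Megatron's argument parser.
    args = SimpleNamespace(tokenizer_type='BertWordPieceCase',
                           vocab_file='bert-vocab.txt',  # assumed path
                           merge_file=None)
    tokenizer = build_tokenizer(args)  # dispatches on tokenizer_type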