Commit 230156c4 authored by yangzhong's avatar yangzhong
Browse files

bert-large training

parents
Pipeline #3006 failed with stages
in 0 seconds
""" Official evaluation script for v1.1 of the SQuAD dataset. """
from __future__ import print_function
from collections import Counter
import string
import re
import argparse
import json
import sys
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def f1_score(prediction, ground_truth):
prediction_tokens = normalize_answer(prediction).split()
ground_truth_tokens = normalize_answer(ground_truth).split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def exact_match_score(prediction, ground_truth):
return (normalize_answer(prediction) == normalize_answer(ground_truth))
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def evaluate(dataset, predictions):
f1 = exact_match = total = 0
for article in dataset:
for paragraph in article['paragraphs']:
for qa in paragraph['qas']:
total += 1
if qa['id'] not in predictions:
message = 'Unanswered question ' + qa['id'] + \
' will receive score 0.'
print(message, file=sys.stderr)
continue
ground_truths = list(map(lambda x: x['text'], qa['answers']))
prediction = predictions[qa['id']]
exact_match += metric_max_over_ground_truths(
exact_match_score, prediction, ground_truths)
f1 += metric_max_over_ground_truths(
f1_score, prediction, ground_truths)
exact_match = 100.0 * exact_match / total
f1 = 100.0 * f1 / total
return {'exact_match': exact_match, 'f1': f1}
if __name__ == '__main__':
expected_version = '1.1'
parser = argparse.ArgumentParser(
description='Evaluation for SQuAD ' + expected_version)
parser.add_argument('dataset_file', help='Dataset file')
parser.add_argument('prediction_file', help='Prediction File')
args = parser.parse_args()
with open(args.dataset_file) as dataset_file:
dataset_json = json.load(dataset_file)
if (dataset_json['version'] != expected_version):
print('Evaluation expects v-' + expected_version +
', but got dataset with v-' + dataset_json['version'],
file=sys.stderr)
dataset = dataset_json['data']
with open(args.prediction_file) as prediction_file:
predictions = json.load(prediction_file)
print(json.dumps(evaluate(dataset, predictions)))
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extract pre-computed feature vectors from a PyTorch BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import collections
import logging
import json
import re
import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tokenization import BertTokenizer
from modeling import BertModel
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)
class InputExample(object):
def __init__(self, unique_id, text_a, text_b):
self.unique_id = unique_id
self.text_a = text_a
self.text_b = text_b
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
self.unique_id = unique_id
self.tokens = tokens
self.input_ids = input_ids
self.input_mask = input_mask
self.input_type_ids = input_type_ids
def convert_examples_to_features(examples, seq_length, tokenizer):
"""Loads a data file into a list of `InputBatch`s."""
features = []
for (ex_index, example) in enumerate(examples):
tokens_a = tokenizer.tokenize(example.text_a)
tokens_b = None
if example.text_b:
tokens_b = tokenizer.tokenize(example.text_b)
if tokens_b:
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > seq_length - 2:
tokens_a = tokens_a[0:(seq_length - 2)]
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambigiously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = []
input_type_ids = []
tokens.append("[CLS]")
input_type_ids.append(0)
for token in tokens_a:
tokens.append(token)
input_type_ids.append(0)
tokens.append("[SEP]")
input_type_ids.append(0)
if tokens_b:
for token in tokens_b:
tokens.append(token)
input_type_ids.append(1)
tokens.append("[SEP]")
input_type_ids.append(1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
# Zero-pad up to the sequence length.
while len(input_ids) < seq_length:
input_ids.append(0)
input_mask.append(0)
input_type_ids.append(0)
assert len(input_ids) == seq_length
assert len(input_mask) == seq_length
assert len(input_type_ids) == seq_length
if ex_index < 5:
logger.info("*** Example ***")
logger.info("unique_id: %s" % (example.unique_id))
logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
logger.info(
"input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))
features.append(
InputFeatures(
unique_id=example.unique_id,
tokens=tokens,
input_ids=input_ids,
input_mask=input_mask,
input_type_ids=input_type_ids))
return features
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
def read_examples(input_file):
"""Read a list of `InputExample`s from an input file."""
examples = []
unique_id = 0
with open(input_file, "r", encoding='utf-8') as reader:
while True:
line = reader.readline()
if not line:
break
line = line.strip()
text_a = None
text_b = None
m = re.match(r"^(.*) \|\|\| (.*)$", line)
if m is None:
text_a = line
else:
text_a = m.group(1)
text_b = m.group(2)
examples.append(
InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
unique_id += 1
return examples
def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--input_file", default=None, type=str, required=True)
parser.add_argument("--output_file", default=None, type=str, required=True)
parser.add_argument("--bert_model", default=None, type=str, required=True,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
## Other parameters
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
"than this will be truncated, and sequences shorter than this will be padded.")
parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
parser.add_argument("--local_rank",
type=int,
default=-1,
help = "local_rank for distributed training on gpus")
parser.add_argument("--no_cuda",
action='store_true',
help="Whether not to use CUDA when available")
args = parser.parse_args()
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()
else:
device = torch.device("cuda", args.local_rank)
n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
logger.info("device: {} n_gpu: {} distributed training: {}".format(device, n_gpu, bool(args.local_rank != -1)))
layer_indexes = [int(x) for x in args.layers.split(",")]
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
examples = read_examples(args.input_file)
features = convert_examples_to_features(
examples=examples, seq_length=args.max_seq_length, tokenizer=tokenizer)
unique_id_to_feature = {}
for feature in features:
unique_id_to_feature[feature.unique_id] = feature
model = BertModel.from_pretrained(args.bert_model)
model.to(device)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index)
if args.local_rank == -1:
eval_sampler = SequentialSampler(eval_data)
else:
eval_sampler = DistributedSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)
model.eval()
with open(args.output_file, "w", encoding='utf-8') as writer:
for input_ids, input_mask, example_indices in eval_dataloader:
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask)
all_encoder_layers = all_encoder_layers
for b, example_index in enumerate(example_indices):
feature = features[example_index.item()]
unique_id = int(feature.unique_id)
# feature = unique_id_to_feature[unique_id]
output_json = collections.OrderedDict()
output_json["linex_index"] = unique_id
all_out_features = []
for (i, token) in enumerate(feature.tokens):
all_layers = []
for (j, layer_index) in enumerate(layer_indexes):
layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy()
layer_output = layer_output[b]
layers = collections.OrderedDict()
layers["index"] = layer_index
layers["values"] = [
round(x.item(), 6) for x in layer_output[i]
]
all_layers.append(layers)
out_features = collections.OrderedDict()
out_features["token"] = token
out_features["layers"] = all_layers
all_out_features.append(out_features)
output_json["features"] = all_out_features
writer.write(json.dumps(output_json) + "\n")
if __name__ == "__main__":
main()
This diff is collapsed.
icon.png

53.8 KB

This diff is collapsed.
DLLL {"timestamp": "1689419433.768081", "datetime": "2023-07-15 19:10:33.768081", "elapsedtime": "0.000237", "type": "LOG", "step": "PARAMETER", "data": {"Config": ["Namespace(amp=True, bert_model='bert-large-uncased', cache_dir=None, config_file='/public/home/hepj//model_source/pytorch_bert/bert_config.json', disable_progress_bar=False, dist_url='tcp://224.66.41.62:23456', do_eval=True, do_lower_case=False, do_predict=True, do_train=True, doc_stride=128, eval_script='./evaluate-v1.1.py', fp16=True, gpus_per_node=1, gradient_accumulation_steps=1, init_checkpoint='/public/home/hepj/model_source/model_pytorch.ckpt.pt', json_summary='./log/results-squad-fp16.json', learning_rate=5e-05, local_rank=-1, log_freq=50, loss_scale=0, max_answer_length=30, max_query_length=64, max_seq_length=384, max_steps=-1.0, n_best_size=20, no_cuda=False, null_score_diff_threshold=0.0, num_train_epochs=3.0, output_dir='/public/home/hepj/outdir/tourch/SQuAD', predict_batch_size=4, predict_file='/public/home/hepj/data/sq1.1/dev-v1.1.json', seed=42, skip_cache=False, skip_checkpoint=False, train_batch_size=4, train_file='/public/home/hepj/data/sq1.1/train-v1.1.json', use_env=False, verbose_logging=False, version_2_with_negative=False, vocab_file='/public/home/hepj//model_source/pytorch_bert/vocab.txt', warmup_proportion=0.1, world_size=1)"]}}
DLLL {"timestamp": "1689419433.787672", "datetime": "2023-07-15 19:10:33.787672", "elapsedtime": "0.019828", "type": "LOG", "step": "PARAMETER", "data": {"SEED": 42}}
DLLL {"timestamp": "1689419453.753193", "datetime": "2023-07-15 19:10:53.753193", "elapsedtime": "19.985349", "type": "LOG", "step": "PARAMETER", "data": {"loading_checkpoint": true}}
DLLL {"timestamp": "1689419456.642115", "datetime": "2023-07-15 19:10:56.642115", "elapsedtime": "22.874271", "type": "LOG", "step": "PARAMETER", "data": {"loaded_checkpoint": true}}
DLLL {"timestamp": "1689419457.266302", "datetime": "2023-07-15 19:10:57.266302", "elapsedtime": "23.498458", "type": "LOG", "step": "PARAMETER", "data": {"model_weights_num": 335150082}}
DLLL {"timestamp": "1689419469.543777", "datetime": "2023-07-15 19:11:09.543777", "elapsedtime": "35.775933", "type": "LOG", "step": "PARAMETER", "data": {"train_start": true}}
DLLL {"timestamp": "1689419469.543959", "datetime": "2023-07-15 19:11:09.543959", "elapsedtime": "35.776115", "type": "LOG", "step": "PARAMETER", "data": {"training_samples": 87599}}
DLLL {"timestamp": "1689419469.54403", "datetime": "2023-07-15 19:11:09.544030", "elapsedtime": "35.776186", "type": "LOG", "step": "PARAMETER", "data": {"training_features": 88368}}
DLLL {"timestamp": "1689419469.544095", "datetime": "2023-07-15 19:11:09.544095", "elapsedtime": "35.776251", "type": "LOG", "step": "PARAMETER", "data": {"train_batch_size": 4}}
DLLL {"timestamp": "1689419469.544156", "datetime": "2023-07-15 19:11:09.544156", "elapsedtime": "35.776312", "type": "LOG", "step": "PARAMETER", "data": {"steps": 65697.0}}
DLLL {"timestamp": "1689419476.360987", "datetime": "2023-07-15 19:11:16.360987", "elapsedtime": "42.593143", "type": "LOG", "step": [0, 1], "data": {"step_loss": 6.122858047485352, "learning_rate": 7.610697596541699e-09}}
DLLL {"timestamp": "1689419492.221115", "datetime": "2023-07-15 19:11:32.221115", "elapsedtime": "58.453271", "type": "LOG", "step": [0, 51], "data": {"step_loss": 5.114989757537842, "learning_rate": 3.8814557742362663e-07}}
DLLL {"timestamp": "1689419507.932752", "datetime": "2023-07-15 19:11:47.932752", "elapsedtime": "74.164908", "type": "LOG", "step": [0, 101], "data": {"step_loss": 5.053555488586426, "learning_rate": 7.686804572507116e-07}}
DLLL {"timestamp": "1689585949.809111", "datetime": "2023-07-17 17:25:49.809111", "elapsedtime": "0.000235", "type": "LOG", "step": "PARAMETER", "data": {"Config": ["Namespace(amp=False, bert_model='bert-large-uncased', cache_dir=None, config_file='/public/home/hepj/model_source/pytorch_bert/bert_config.json', disable_progress_bar=False, dist_url='tcp://224.66.41.62:23456', do_eval=False, do_lower_case=False, do_predict=True, do_train=True, doc_stride=128, eval_script='./evaluate-v1.1.py', fp16=False, gpus_per_node=1, gradient_accumulation_steps=1, init_checkpoint='/public/home/hepj/model_source/pytorch_bert/model.ckpt-28252.pt', json_summary='./log/results.json', learning_rate=5e-05, local_rank=-1, log_freq=50, loss_scale=0, max_answer_length=30, max_query_length=64, max_seq_length=384, max_steps=-1.0, n_best_size=20, no_cuda=False, null_score_diff_threshold=0.0, num_train_epochs=3.0, output_dir='/public/home/hepj/outdir/torch/SQuAD', predict_batch_size=4, predict_file='/public/home/hepj/data/sq1.1/dev-v1.1.json', seed=42, skip_cache=False, skip_checkpoint=False, train_batch_size=4, train_file='/public/home/hepj/data/sq1.1/train-v1.1.json', use_env=False, verbose_logging=False, version_2_with_negative=False, vocab_file='/public/home/hepj/model_source/pytorch_bert/vocab.txt', warmup_proportion=0.1, world_size=1)"]}}
DLLL {"timestamp": "1689585950.006137", "datetime": "2023-07-17 17:25:50.006137", "elapsedtime": "0.197261", "type": "LOG", "step": "PARAMETER", "data": {"SEED": 42}}
DLLL {"timestamp": "1689585970.324955", "datetime": "2023-07-17 17:26:10.324955", "elapsedtime": "20.516079", "type": "LOG", "step": "PARAMETER", "data": {"loading_checkpoint": true}}
DLLL {"timestamp": "1689585974.448674", "datetime": "2023-07-17 17:26:14.448674", "elapsedtime": "24.639798", "type": "LOG", "step": "PARAMETER", "data": {"loaded_checkpoint": true}}
DLLL {"timestamp": "1689585976.67685", "datetime": "2023-07-17 17:26:16.676850", "elapsedtime": "26.867974", "type": "LOG", "step": "PARAMETER", "data": {"model_weights_num": 335150082}}
DLLL {"timestamp": "1689585989.449134", "datetime": "2023-07-17 17:26:29.449134", "elapsedtime": "39.640258", "type": "LOG", "step": "PARAMETER", "data": {"train_start": true}}
DLLL {"timestamp": "1689585989.467614", "datetime": "2023-07-17 17:26:29.467614", "elapsedtime": "39.658738", "type": "LOG", "step": "PARAMETER", "data": {"training_samples": 87599}}
DLLL {"timestamp": "1689585989.467693", "datetime": "2023-07-17 17:26:29.467693", "elapsedtime": "39.658817", "type": "LOG", "step": "PARAMETER", "data": {"training_features": 88368}}
DLLL {"timestamp": "1689585989.467758", "datetime": "2023-07-17 17:26:29.467758", "elapsedtime": "39.658882", "type": "LOG", "step": "PARAMETER", "data": {"train_batch_size": 4}}
DLLL {"timestamp": "1689585989.46782", "datetime": "2023-07-17 17:26:29.467820", "elapsedtime": "39.658944", "type": "LOG", "step": "PARAMETER", "data": {"steps": 65697.0}}
DLLL {"timestamp": "1689586004.55256", "datetime": "2023-07-17 17:26:44.552560", "elapsedtime": "54.743684", "type": "LOG", "step": [0, 1], "data": {"step_loss": 6.121078014373779, "learning_rate": 5e-05}}
# 模型唯一标识
modelCode=309
# 模型名称
modelName=BERT_pytorch
# 模型描述
modelDescription=BERT是一种新的基于Transformer应用于计算机视觉领域的神经网络模型,基于PyTorch实现测试
# 应用场景
appScenario=训练,对话问答,互联网,教育,科研
# 框架类型
frameType=PyTorch
This diff is collapsed.
# coding=utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""
import math
import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_
#from fused_adam_local import FusedAdam
from apex.optimizers import FusedAdam
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from utils import is_main_process
multi_tensor_l2norm = amp_C.multi_tensor_l2norm
lamb_compute_update = amp_C.multi_tensor_lamb_stage1_cuda
lamb_apply_update = amp_C.multi_tensor_lamb_stage2_cuda
scale = amp_C.multi_tensor_scale
def warmup_cosine(x, warmup=0.002):
if x < warmup:
return x/warmup
return 0.5 * (1.0 + torch.cos(math.pi * x))
def warmup_constant(x, warmup=0.002):
if x < warmup:
return x/warmup
return 1.0
def warmup_linear(x, warmup=0.002):
if x < warmup:
return x/warmup
return max((x - 1. )/ (warmup - 1.), 0.)
def warmup_poly(x, warmup=0.002, degree=0.5):
if x < warmup:
return x/warmup
return (1.0 - x)**degree
SCHEDULES = {
'warmup_cosine':warmup_cosine,
'warmup_constant':warmup_constant,
'warmup_linear':warmup_linear,
'warmup_poly':warmup_poly,
}
class BertAdam(Optimizer):
"""Implements BERT version of Adam algorithm with weight decay fix.
Params:
lr: learning rate
warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
t_total: total number of training steps for the learning
rate schedule, -1 means constant learning rate. Default: -1
schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
b1: Adams b1. Default: 0.9
b2: Adams b2. Default: 0.999
e: Adams epsilon. Default: 1e-6
weight_decay: Weight decay. Default: 0.01
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
"""
def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
max_grad_norm=1.0):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
if schedule not in SCHEDULES:
raise ValueError("Invalid schedule parameter: {}".format(schedule))
if not 0.0 <= warmup < 1.0 and not warmup == -1:
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
if not 0.0 <= b1 < 1.0:
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
if not 0.0 <= b2 < 1.0:
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
if not e >= 0.0:
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(BertAdam, self).__init__(params, defaults)
def get_lr(self):
lr = []
for group in self.param_groups:
for p in group['params']:
state = self.state[p]
if len(state) == 0:
return [0]
if group['t_total'] != -1:
schedule_fct = SCHEDULES[group['schedule']]
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
else:
lr_scheduled = group['lr']
lr.append(lr_scheduled)
return lr
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['next_m'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['next_v'] = torch.zeros_like(p.data)
next_m, next_v = state['next_m'], state['next_v']
beta1, beta2 = group['b1'], group['b2']
# Add grad clipping
if group['max_grad_norm'] > 0:
clip_grad_norm_(p, group['max_grad_norm'])
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
next_m.mul_(beta1).add_(1 - beta1, grad)
next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
update = next_m / (next_v.sqrt() + group['e'])
# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
if group['weight_decay'] > 0.0:
update += group['weight_decay'] * p.data
if group['t_total'] != -1:
schedule_fct = SCHEDULES[group['schedule']]
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
else:
lr_scheduled = group['lr']
update_with_lr = lr_scheduled * update
p.data.add_(-update_with_lr)
state['step'] += 1
return loss
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment