Commit 50b7e52a authored by thomwolf

WIP examples

parent ed6c8d37
@@ -37,7 +37,7 @@ from pytorch_transformers import (BertForSequenceClassification, XLNetForSequenc
                                   XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
 from pytorch_transformers import (BertTokenizer, XLNetTokenizer,
                                   XLMTokenizer)
-from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
+from pytorch_transformers.optimization import BertAdam
 
 from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
@@ -60,12 +60,12 @@ TOKENIZER_CLASSES = {
     'xlm': XLMTokenizer,
 }
 
-def train(args, train_dataset, model):
+def train(args, train_dataset, model, tokenizer):
     """ Train the model """
     if args.local_rank in [-1, 0]:
         tb_writer = SummaryWriter()
 
-    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
+    args.train_batch_size = args.per_gpu_train_batch_size * args.n_gpu
     train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
     train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
@@ -76,42 +76,36 @@ def train(args, train_dataset, model):
     num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
 
     # Prepare optimizer
-    param_optimizer = list(model.named_parameters())
-    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+    no_decay = ['bias', 'LayerNorm.weight']
     optimizer_grouped_parameters = [
-        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
-        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
+        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
+        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
         ]
+    optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate,
+                         t_total=num_train_optimization_steps, warmup=args.warmup_proportion)
     if args.fp16:
         try:
-            from apex.optimizers import FP16_Optimizer, FusedAdam
+            from apex import amp
         except ImportError:
             raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
-        optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False, max_grad_norm=1.0)
-        if args.loss_scale == 0:
-            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
-        else:
-            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
-        warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion, t_total=num_train_optimization_steps)
-    else:
-        optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion,
-                             t_total=num_train_optimization_steps)
+        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
 
     # Train!
     logger.info("***** Running training *****")
     logger.info(" Num examples = %d", len(train_dataset))
     logger.info(" Num Epochs = %d", args.num_train_epochs)
-    logger.info(" Batch size = %d", args.train_batch_size)
+    logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
+    logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
+                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
     logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
     logger.info(" Total optimization steps = %d", num_train_optimization_steps)
 
     global_step = 0
-    tr_loss = 0
-    model.train()
+    tr_loss, logging_loss = 0.0, 0.0
     optimizer.zero_grad()
     for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]):
         for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
+            model.train()
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {'input_ids': batch[0],
                       'attention_mask': batch[1],
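
Not part of the commit: a minimal sketch of the batch-size arithmetic behind the new per-GPU flag and the added logging lines, using hypothetical values for the flags.

# Hypothetical values, for illustration only.
per_gpu_train_batch_size = 8
n_gpu = 2                          # GPUs visible to a single process (DataParallel)
gradient_accumulation_steps = 4
world_size = 1                     # torch.distributed world size; 1 when local_rank == -1

train_batch_size = per_gpu_train_batch_size * n_gpu                # what the DataLoader sees
total = train_batch_size * gradient_accumulation_steps * world_size
print(train_batch_size, total)     # 16 64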
@@ -125,23 +119,25 @@ def train(args, train_dataset, model):
             if args.gradient_accumulation_steps > 1:
                 loss = loss / args.gradient_accumulation_steps
 
-            loss.backward() if not args.fp16 else optimizer.backward(loss)
+            if args.fp16:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                loss.backward()
 
             tr_loss += loss.item()
             if (step + 1) % args.gradient_accumulation_steps == 0:
-                if args.fp16:
-                    # modify learning rate with special warm up BERT uses
-                    # if args.fp16 is False, BertAdam is used that handles this automatically
-                    lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
-                    for param_group in optimizer.param_groups:
-                        param_group['lr'] = lr_this_step
                 optimizer.step()
                 optimizer.zero_grad()
                 global_step += 1
 
-                if args.local_rank in [-1, 0]:
-                    if not args.fp16:
-                        tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
-                    tb_writer.add_scalar('loss', loss.item(), global_step)
+                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
+                    if args.local_rank == -1:  # Only evaluate on single GPU otherwise metrics may not average well
+                        results = evaluate(args, model, tokenizer)
+                        for key, value in results.items():
+                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
+                    tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
+                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
+                    logging_loss = tr_loss
 
             if args.max_steps > 0 and global_step > args.max_steps:
                 break
         if args.max_steps > 0 and global_step > args.max_steps:
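
Not part of the commit: the two hunks above swap the old FP16_Optimizer/FusedAdam path for NVIDIA apex AMP. A minimal sketch of how the new pieces fit together, assuming apex and a CUDA device are available and using a placeholder model and optimizer.

import torch
from apex import amp

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)      # stand-in for BertAdam
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

for _ in range(3):
    loss = model(torch.randn(4, 10, device='cuda')).sum()
    with amp.scale_loss(loss, optimizer) as scaled_loss:       # loss scaling handled by amp
        scaled_loss.backward()
    optimizer.step()
    optimizer.zero_grad()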
@@ -150,62 +146,71 @@ def train(args, train_dataset, model):
     return global_step, tr_loss / global_step
 
 
-def evalutate(args, eval_task, eval_output_dir, dataset, model):
-    """ Evaluate the model """
-    if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
-        os.makedirs(eval_output_dir)
-
-    # Note that DistributedSampler samples randomly
-    eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
-    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
-
-    # Eval!
-    logger.info("***** Running evaluation *****")
-    logger.info(" Num examples = %d", len(dataset))
-    logger.info(" Batch size = %d", args.eval_batch_size)
-    model.eval()
-    eval_loss = 0
-    nb_eval_steps = 0
-    preds = None
-    out_label_ids = None
-    for batch in tqdm(eval_dataloader, desc="Evaluating"):
-        batch = tuple(t.to(args.device) for t in batch)
-
-        with torch.no_grad():
-            inputs = {'input_ids': batch[0],
-                      'attention_mask': batch[1],
-                      'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM don't use segment_ids
-                      'labels': batch[3]}
-            outputs = model(**inputs)
-            tmp_eval_loss, logits = outputs[:2]
-            eval_loss += tmp_eval_loss.mean().item()
-        nb_eval_steps += 1
-        if preds is None:
-            preds = logits.detach().cpu().numpy()
-            out_label_ids = inputs['labels'].detach().cpu().numpy()
-        else:
-            preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
-            out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
-
-    eval_loss = eval_loss / nb_eval_steps
-    if args.output_mode == "classification":
-        preds = np.argmax(preds, axis=1)
-    elif args.output_mode == "regression":
-        preds = np.squeeze(preds)
-    result = compute_metrics(eval_task, preds, out_label_ids)
-
-    output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
-    with open(output_eval_file, "w") as writer:
-        logger.info("***** Eval results *****")
-        for key in sorted(result.keys()):
-            logger.info(" %s = %s", key, str(result[key]))
-            writer.write("%s = %s\n" % (key, str(result[key])))
-
-    return result
-
-def load_and_cache_examples(args, task, tokenizer, evaluate=False):
+def evaluate(args, model, tokenizer):
+    # Loop to handle MNLI double evaluation (matched, mis-matched)
+    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
+    eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)
+
+    results = {}
+    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
+        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
+
+        """ Evaluate the model """
+        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
+            os.makedirs(eval_output_dir)
+
+        # Note that DistributedSampler samples randomly
+        eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
+        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
+
+        # Eval!
+        logger.info("***** Running evaluation *****")
+        logger.info(" Num examples = %d", len(eval_dataset))
+        logger.info(" Batch size = %d", args.eval_batch_size)
+        model.eval()
+        eval_loss = 0
+        nb_eval_steps = 0
+        preds = None
+        out_label_ids = None
+        for batch in tqdm(eval_dataloader, desc="Evaluating"):
+            batch = tuple(t.to(args.device) for t in batch)
+
+            with torch.no_grad():
+                inputs = {'input_ids': batch[0],
+                          'attention_mask': batch[1],
+                          'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM don't use segment_ids
+                          'labels': batch[3]}
+                outputs = model(**inputs)
+                tmp_eval_loss, logits = outputs[:2]
+
+                eval_loss += tmp_eval_loss.mean().item()
+            nb_eval_steps += 1
+            if preds is None:
+                preds = logits.detach().cpu().numpy()
+                out_label_ids = inputs['labels'].detach().cpu().numpy()
+            else:
+                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
+                out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
+
+        eval_loss = eval_loss / nb_eval_steps
+        if args.output_mode == "classification":
+            preds = np.argmax(preds, axis=1)
+        elif args.output_mode == "regression":
+            preds = np.squeeze(preds)
+        result = compute_metrics(eval_task, preds, out_label_ids)
+        results.update(result)
+
+        output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
+        with open(output_eval_file, "w") as writer:
+            logger.info("***** Eval results *****")
+            for key in sorted(result.keys()):
+                logger.info(" %s = %s", key, str(result[key]))
+                writer.write("%s = %s\n" % (key, str(result[key])))
+
+    return results
+
+def load_and_cache_examples(args, task, tokenizer, evaluate=False, overwrite_cache=False):
     processor = processors[task]()
     output_mode = output_modes[task]
     # Load data features from cache or dataset file
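
Not part of the commit: the rewritten evaluate() folds the MNLI matched/mis-matched handling into the function itself. A small illustration (hypothetical task name and output directory) of what the two tuples it loops over expand to.

task_name, output_dir = "mnli", "/tmp/mnli_out"      # hypothetical values
eval_task_names = ("mnli", "mnli-mm") if task_name == "mnli" else (task_name,)
eval_outputs_dirs = (output_dir, output_dir + '-MM') if task_name == "mnli" else (output_dir,)
print(list(zip(eval_task_names, eval_outputs_dirs)))
# [('mnli', '/tmp/mnli_out'), ('mnli-mm', '/tmp/mnli_out-MM')]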
@@ -214,7 +219,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
         list(filter(None, args.model_name.split('/'))).pop(),
         str(args.max_seq_length),
         str(task)))
-    if os.path.exists(cached_features_file):
+    if os.path.exists(cached_features_file) and not args.overwrite_cache:
         logger.info("Loading features from cached file %s", cached_features_file)
         features = torch.load(cached_features_file)
     else:
@@ -270,39 +275,44 @@ def main():
                         help="Whether to run eval on the dev set.")
     parser.add_argument("--do_lower_case", action='store_true',
                         help="Set this flag if you are using an uncased model.")
-    parser.add_argument("--train_batch_size", default=32, type=int,
-                        help="Total batch size for training.")
+    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
+                        help="Batch size per GPU for training.")
     parser.add_argument("--eval_batch_size", default=8, type=int,
                         help="Total batch size for eval.")
     parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                         help="Number of updates steps to accumulate before performing a backward/update pass.")
     parser.add_argument("--learning_rate", default=5e-5, type=float,
                         help="The initial learning rate for Adam.")
+    parser.add_argument("--weight_decay", default=0.0, type=float,
+                        help="Weight deay if we apply some.")
     parser.add_argument("--num_train_epochs", default=3.0, type=float,
                         help="Total number of training epochs to perform.")
     parser.add_argument("--max_steps", default=-1, type=int,
                         help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
     parser.add_argument("--warmup_proportion", default=0.1, type=float,
                         help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
+    parser.add_argument('--logging_steps', type=int, default=100,
+                        help="Log every X updates steps.")
     parser.add_argument("--no_cuda", action='store_true',
                         help="Avoid using CUDA when available")
     parser.add_argument('--overwrite_output_dir', action='store_true',
                         help="Overwrite the content of the output directory")
+    parser.add_argument('--overwrite_cache', action='store_true',
+                        help="Overwrite the cached training and evaluation sets")
     parser.add_argument('--seed', type=int, default=42,
                         help="random seed for initialization")
     parser.add_argument('--fp16', action='store_true',
-                        help="Whether to use 16-bit float precision instead of 32-bit")
-    parser.add_argument('--loss_scale', type=float, default=0,
-                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
-                             "0 (default value): dynamic loss scaling.\n"
-                             "Positive power of 2: static loss scaling value.\n")
+                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
+    parser.add_argument('--fp16_opt_level', type=str, default='O1',
+                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+                             "See details at https://nvidia.github.io/apex/amp.html")
     parser.add_argument("--local_rank", type=int, default=-1,
-                        help="local_rank for distributed training on gpus")
-    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
-    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
+                        help="For distributed training: local_rank")
+    parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
+    parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
     args = parser.parse_args()
 
     if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
@@ -362,13 +372,10 @@ def main():
     if args.local_rank == 0:
         torch.distributed.barrier()
 
-    # Distributed, parrallel and fp16 model
-    if args.fp16:
-        model.half()
+    # Distributed and parrallel training
     model.to(args.device)
     if args.local_rank != -1:
-        model = torch.nn.parallel.DistributedDataParallel(model,
-                                                          device_ids=[args.local_rank],
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                           output_device=args.local_rank,
                                                           find_unused_parameters=True)
     elif args.n_gpu > 1:
@@ -377,7 +384,7 @@ def main():
     # Training
     if args.do_train:
         train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
-        global_step, tr_loss = train(args, train_dataset, model)
+        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
@@ -402,17 +409,10 @@ def main():
         model.to(args.device)
 
     # Evaluation
-    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
-        # Handle MNLI double evaluation
-        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
-        eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)
-
-        for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
-            eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
-            result = evalutate(args, eval_task, eval_output_dir, eval_dataset, model)
-
-    return result
+    if args.do_eval and args.local_rank in [-1, 0]:
+        results = evaluate(args, model, tokenizer)
+
+    return results
 
 
 if __name__ == "__main__":
...
This diff is collapsed.
# Copyright (c) 2019-present, the HuggingFace Inc. authors.
# All rights reserved. This source code is licensed under the BSD-style
# license found in the LICENSE file in the root directory of this source tree.
import logging
import os
from tqdm import tqdm
from pprint import pformat

import torch

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler, OutputHandler, TensorboardLogger


def average_distributed_scalar(scalar, args):
    """ Average a scalar over nodes if we are in distributed training.
        We use this for distributed evaluation.
        Beware, such averages only works for metrics which are additive with regard
        to the evaluation dataset, e.g. accuracy, log probabilities.
        Doesn't work for ratio metrics like F1.
    """
    if args.local_rank == -1:
        return scalar
    scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device) / torch.distributed.get_world_size()
    torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
    return scalar_t.item()


def add_logging_and_checkpoint_saving(trainer, evaluator, metrics, model, optimizer, args, prefix=""):
    """ Add to a PyTorch ignite training engine tensorboard logging,
        progress bar with average loss, checkpoint saving and save training config.
    """
    # Add progress bar with average loss
    RunningAverage(output_transform=lambda x: x).attach(trainer, prefix + "loss")
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=[prefix + "loss"])
    evaluator.add_event_handler(Events.COMPLETED, lambda _:
                                pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

    # Add tensorboard logging with training and evaluation metrics
    tb_logger = TensorboardLogger(log_dir=None)
    tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=[prefix + "loss"]),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.ITERATION_STARTED)

    @evaluator.on(Events.COMPLETED)
    def tb_log_metrics(engine):
        for name in metrics.keys():
            tb_logger.writer.add_scalar(name, engine.state.metrics[name], trainer.state.iteration)

    # Add checkpoint saving after each epoch - take care of distributed encapsulation ('getattr()')
    checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)})

    # Save training configuration
    torch.save(args, os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))

    return checkpoint_handler, tb_logger
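
Not part of the commit: average_distributed_scalar divides the local value by the world size and then all-reduces with SUM, which is exactly the mean over processes. A hedged usage sketch for the non-distributed case, assuming the helper above is in scope (note that CONFIG_NAME used in the last function is presumably imported elsewhere in the file; the import block shown here may be truncated).

import argparse
import torch

args = argparse.Namespace(local_rank=-1, device=torch.device("cpu"))
local_accuracy = 0.83                                    # metric computed on this process's data shard
print(average_distributed_scalar(local_accuracy, args))  # 0.83; with local_rank != -1 it is averaged over processes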
@@ -393,7 +393,7 @@ class XLNetRelativeAttention(nn.Module):
         x = x[1:, ...]
         x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
         # x = x[:, 0:klen, :, :]
-        x = torch.index_select(x, 1, torch.arange(klen))
+        x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long))
 
         return x
...
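
Not part of the commit: the one-line change above builds the index tensor on the same device as x. torch.index_select requires a LongTensor index on the input's device, so a plain torch.arange(klen) (created on CPU) would fail once the model runs on GPU. A small sketch of the fixed pattern with hypothetical shapes:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(4, 6, 2, 3, device=device)
klen = 5
idx = torch.arange(klen, device=x.device, dtype=torch.long)
print(torch.index_select(x, 1, idx).shape)   # torch.Size([4, 5, 2, 3])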
@@ -227,6 +227,8 @@ class BertAdam(Optimizer):
         lr = []
         for group in self.param_groups:
             for p in group['params']:
+                if p.grad is None:
+                    continue
                 state = self.state[p]
                 if len(state) == 0:
                     return [0]
...
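
Not part of the commit: the added guard skips parameters that never received a gradient, which otherwise have no optimizer state and make get_lr() bail out with [0]. A small illustration of when that happens, using hypothetical modules rather than the library's own classes:

import torch

used = torch.nn.Linear(4, 4)
unused = torch.nn.Linear(4, 4)             # e.g. a head the current task never touches
loss = used(torch.randn(2, 4)).sum()
loss.backward()
print([p.grad is None for p in used.parameters()])    # [False, False]
print([p.grad is None for p in unused.parameters()])  # [True, True]  -> skipped by the new guard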