"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "fdc05cd68fa005f0e6bf7cff49f357c6c9a504a3"
Commit 20e65220 authored by lukovnikov's avatar lukovnikov
Browse files

relation classification: replacing entity mention with mask token

parent eac039d2
...@@ -130,7 +130,7 @@ class BertAdam(Optimizer): ...@@ -130,7 +130,7 @@ class BertAdam(Optimizer):
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
""" """
def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, init_weight_decay=0.,
max_grad_norm=1.0): max_grad_norm=1.0):
if lr is not required and lr < 0.0: if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
...@@ -150,7 +150,7 @@ class BertAdam(Optimizer): ...@@ -150,7 +150,7 @@ class BertAdam(Optimizer):
if warmup != -1 or t_total != -1: if warmup != -1 or t_total != -1:
logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided.") logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided.")
defaults = dict(lr=lr, schedule=schedule, defaults = dict(lr=lr, schedule=schedule,
b1=b1, b2=b2, e=e, weight_decay=weight_decay, b1=b1, b2=b2, e=e, weight_decay=weight_decay, init_weight_decay=init_weight_decay,
max_grad_norm=max_grad_norm) max_grad_norm=max_grad_norm)
super(BertAdam, self).__init__(params, defaults) super(BertAdam, self).__init__(params, defaults)
...@@ -220,6 +220,8 @@ class BertAdam(Optimizer): ...@@ -220,6 +220,8 @@ class BertAdam(Optimizer):
if group['weight_decay'] > 0.0: if group['weight_decay'] > 0.0:
update += group['weight_decay'] * p.data update += group['weight_decay'] * p.data
# TODO: init weight decay
lr_scheduled = group['lr'] lr_scheduled = group['lr']
lr_scheduled *= group['schedule'].get_lr(state['step']) lr_scheduled *= group['schedule'].get_lr(state['step'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment