Commit 41a1b292 authored by Leif's avatar Leif
Browse files

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents 9471054e 3d30899b
...@@ -28,12 +28,15 @@ from .e2e_metric import E2EMetric ...@@ -28,12 +28,15 @@ from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric from .distillation_metric import DistillationMetric
from .table_metric import TableMetric from .table_metric import TableMetric
from .kie_metric import KIEMetric from .kie_metric import KIEMetric
from .vqa_token_ser_metric import VQASerTokenMetric
from .vqa_token_re_metric import VQAReTokenMetric
def build_metric(config): def build_metric(config):
support_dict = [ support_dict = [
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DetMetric", "RecMetric", "ClsMetric", "E2EMetric",
"DistillationMetric", "TableMetric", 'KIEMetric' "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
'VQAReTokenMetric'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
class ClsMetric(object): class ClsMetric(object):
def __init__(self, main_indicator='acc', **kwargs): def __init__(self, main_indicator='acc', **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.eps = 1e-5
self.reset() self.reset()
def __call__(self, pred_label, *args, **kwargs): def __call__(self, pred_label, *args, **kwargs):
...@@ -28,7 +29,7 @@ class ClsMetric(object): ...@@ -28,7 +29,7 @@ class ClsMetric(object):
all_num += 1 all_num += 1
self.correct_num += correct_num self.correct_num += correct_num
self.all_num += all_num self.all_num += all_num
return {'acc': correct_num / all_num, } return {'acc': correct_num / (all_num + self.eps), }
def get_metric(self): def get_metric(self):
""" """
...@@ -36,7 +37,7 @@ class ClsMetric(object): ...@@ -36,7 +37,7 @@ class ClsMetric(object):
'acc': 0 'acc': 0
} }
""" """
acc = self.correct_num / self.all_num acc = self.correct_num / (self.all_num + self.eps)
self.reset() self.reset()
return {'acc': acc} return {'acc': acc}
......
...@@ -20,6 +20,7 @@ class RecMetric(object): ...@@ -20,6 +20,7 @@ class RecMetric(object):
def __init__(self, main_indicator='acc', is_filter=False, **kwargs): def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.is_filter = is_filter self.is_filter = is_filter
self.eps = 1e-5
self.reset() self.reset()
def _normalize_text(self, text): def _normalize_text(self, text):
...@@ -47,8 +48,8 @@ class RecMetric(object): ...@@ -47,8 +48,8 @@ class RecMetric(object):
self.all_num += all_num self.all_num += all_num
self.norm_edit_dis += norm_edit_dis self.norm_edit_dis += norm_edit_dis
return { return {
'acc': correct_num / all_num, 'acc': correct_num / (all_num + self.eps),
'norm_edit_dis': 1 - norm_edit_dis / (all_num + 1e-3) 'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps)
} }
def get_metric(self): def get_metric(self):
...@@ -58,8 +59,8 @@ class RecMetric(object): ...@@ -58,8 +59,8 @@ class RecMetric(object):
'norm_edit_dis': 0, 'norm_edit_dis': 0,
} }
""" """
acc = 1.0 * self.correct_num / (self.all_num + 1e-3) acc = 1.0 * self.correct_num / (self.all_num + self.eps)
norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + 1e-3) norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
self.reset() self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis} return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
......
...@@ -12,9 +12,12 @@ ...@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
class TableMetric(object): class TableMetric(object):
def __init__(self, main_indicator='acc', **kwargs): def __init__(self, main_indicator='acc', **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.eps = 1e-5
self.reset() self.reset()
def __call__(self, pred, batch, *args, **kwargs): def __call__(self, pred, batch, *args, **kwargs):
...@@ -31,9 +34,7 @@ class TableMetric(object): ...@@ -31,9 +34,7 @@ class TableMetric(object):
correct_num += 1 correct_num += 1
self.correct_num += correct_num self.correct_num += correct_num
self.all_num += all_num self.all_num += all_num
return { return {'acc': correct_num * 1.0 / (all_num + self.eps), }
'acc': correct_num * 1.0 / all_num,
}
def get_metric(self): def get_metric(self):
""" """
...@@ -41,7 +42,7 @@ class TableMetric(object): ...@@ -41,7 +42,7 @@ class TableMetric(object):
'acc': 0, 'acc': 0,
} }
""" """
acc = 1.0 * self.correct_num / self.all_num acc = 1.0 * self.correct_num / (self.all_num + self.eps)
self.reset() self.reset()
return {'acc': acc} return {'acc': acc}
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
__all__ = ['VQAReTokenMetric']


class VQAReTokenMetric(object):
    """Accumulate relation-extraction results and report precision/recall/hmean.

    Each call stores the predicted relations together with the ground-truth
    relation and entity records. ``get_metric`` rebuilds the ground-truth
    relations from those records, scores them against the predictions in
    "boundaries" mode and clears the accumulated state.

    Args:
        main_indicator (str): name of the metric treated as the main one.
            It is only stored on the instance here.
        **kwargs: accepted and ignored.
    """

    def __init__(self, main_indicator='hmean', **kwargs):
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        """Store one batch of results.

        Args:
            preds (tuple): ``(pred_relations, relations, entities)``, three
                sequences that are appended to the internal lists.
            batch: unused; kept for interface compatibility.
        """
        pred_relations, relations, entities = preds
        self.pred_relations_list.extend(pred_relations)
        self.relations_list.extend(relations)
        self.entities_list.extend(entities)

    def get_metric(self):
        """Score everything accumulated so far, then reset.

        Returns:
            dict: ``{"precision": ..., "recall": ..., "hmean": ...}`` taken
            from the micro-averaged ("ALL") scores of ``re_score``.
        """
        gt_relations = []
        for idx, relations in enumerate(self.relations_list):
            entities = self.entities_list[idx]
            rel_sent = []
            for head, tail in zip(relations["head"], relations["tail"]):
                rel_sent.append({
                    "head_id": head,
                    "head": (entities["start"][head], entities["end"][head]),
                    "head_type": entities["label"][head],
                    "tail_id": tail,
                    "tail": (entities["start"][tail], entities["end"][tail]),
                    "tail_type": entities["label"][tail],
                    # every ground-truth relation gets relation type 1
                    "type": 1,
                })
            gt_relations.append(rel_sent)

        re_metrics = self.re_score(
            self.pred_relations_list, gt_relations, mode="boundaries")
        metrics = {
            "precision": re_metrics["ALL"]["p"],
            "recall": re_metrics["ALL"]["r"],
            "hmean": re_metrics["ALL"]["f1"],
        }
        self.reset()
        return metrics

    def reset(self):
        """Drop all accumulated predictions and ground truth."""
        self.pred_relations_list = []
        self.relations_list = []
        self.entities_list = []

    @staticmethod
    def _precision_recall_f1(tp, fp, fn):
        """Return ``(precision, recall, f1)``; all three are 0 when ``tp`` is 0."""
        if not tp:
            return 0, 0, 0
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
        return precision, recall, f1

    def re_score(self, pred_relations, gt_relations, mode="strict"):
        """Evaluate relation-extraction predictions.

        Args:
            pred_relations (list): list of lists of predicted relations
                (several relations per sentence).
            gt_relations (list): list of lists of ground-truth relations.
                Each relation is a dict such as
                ``{"head": (start_idx, end_idx), "tail": (start_idx, end_idx),
                "head_type": ent_type, "tail_type": ent_type,
                "type": rel_type}``.
            mode (str): ``"strict"`` compares spans and argument types,
                ``"boundaries"`` compares spans only.

        Returns:
            dict: per relation type, and under ``"ALL"``, the counts
            ``tp``/``fp``/``fn`` and the scores ``p``/``r``/``f1``. ``"ALL"``
            holds the micro-averaged values plus ``Macro_f1``, ``Macro_p``
            and ``Macro_r``.
        """
        assert mode in ["strict", "boundaries"]

        # relation type 0 is excluded from scoring
        relation_types = [v for v in [0, 1] if v != 0]
        scores = {
            rel: {"tp": 0, "fp": 0, "fn": 0}
            for rel in relation_types + ["ALL"]
        }

        if mode == "strict":
            # strict mode takes argument types into account
            def key(rel):
                return (rel["head"], rel["head_type"], rel["tail"],
                        rel["tail_type"])
        else:
            # boundaries mode only takes argument spans into account
            def key(rel):
                return (rel["head"], rel["tail"])

        # Count TP, FP and FN per type
        for pred_sent, gt_sent in zip(pred_relations, gt_relations):
            for rel_type in relation_types:
                pred_rels = {
                    key(rel) for rel in pred_sent if rel["type"] == rel_type
                }
                gt_rels = {
                    key(rel) for rel in gt_sent if rel["type"] == rel_type
                }
                scores[rel_type]["tp"] += len(pred_rels & gt_rels)
                scores[rel_type]["fp"] += len(pred_rels - gt_rels)
                scores[rel_type]["fn"] += len(gt_rels - pred_rels)

        # Per-type precision / recall / F1
        for rel_type in relation_types:
            entry = scores[rel_type]
            entry["p"], entry["r"], entry["f1"] = self._precision_recall_f1(
                entry["tp"], entry["fp"], entry["fn"])

        # Micro-averaged scores over all relation types
        tp = sum(scores[rel_type]["tp"] for rel_type in relation_types)
        fp = sum(scores[rel_type]["fp"] for rel_type in relation_types)
        fn = sum(scores[rel_type]["fn"] for rel_type in relation_types)
        precision, recall, f1 = self._precision_recall_f1(tp, fp, fn)

        overall = scores["ALL"]
        overall["p"] = precision
        overall["r"] = recall
        overall["f1"] = f1
        overall["tp"] = tp
        overall["fp"] = fp
        overall["fn"] = fn

        # Macro-averaged scores
        overall["Macro_f1"] = np.mean(
            [scores[rel_type]["f1"] for rel_type in relation_types])
        overall["Macro_p"] = np.mean(
            [scores[rel_type]["p"] for rel_type in relation_types])
        overall["Macro_r"] = np.mean(
            [scores[rel_type]["r"] for rel_type in relation_types])

        return scores
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
__all__ = ['VQASerTokenMetric']


class VQASerTokenMetric(object):
    """Accumulate token-classification results and report precision/recall/hmean.

    Args:
        main_indicator (str): name of the metric treated as the main one.
            It is only stored on the instance here.
        **kwargs: accepted and ignored.
    """

    def __init__(self, main_indicator='hmean', **kwargs):
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        """Store one batch of results.

        Args:
            preds (tuple): ``(predictions, labels)``; both sequences are
                appended to the internal lists.
            batch: unused; kept for interface compatibility.
        """
        preds, labels = preds
        self.pred_list.extend(preds)
        self.gt_list.extend(labels)

    def get_metric(self):
        """Score everything accumulated so far with seqeval, then reset.

        Returns:
            dict: ``{"precision": ..., "recall": ..., "hmean": ...}`` where
            ``hmean`` is seqeval's ``f1_score``.
        """
        # Imported here so that seqeval is only needed when metrics are
        # actually computed.
        from seqeval.metrics import f1_score, precision_score, recall_score

        metrics = {
            "precision": precision_score(self.gt_list, self.pred_list),
            "recall": recall_score(self.gt_list, self.pred_list),
            "hmean": f1_score(self.gt_list, self.pred_list),
        }
        self.reset()
        return metrics

    def reset(self):
        """Drop all accumulated predictions and labels."""
        self.pred_list = []
        self.gt_list = []
...@@ -63,6 +63,10 @@ class BaseModel(nn.Layer): ...@@ -63,6 +63,10 @@ class BaseModel(nn.Layer):
in_channels = self.neck.out_channels in_channels = self.neck.out_channels
# # build head, head is need for det, rec and cls # # build head, head is need for det, rec and cls
if 'Head' not in config or config['Head'] is None:
self.use_head = False
else:
self.use_head = True
config["Head"]['in_channels'] = in_channels config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"]) self.head = build_head(config["Head"])
...@@ -77,6 +81,7 @@ class BaseModel(nn.Layer): ...@@ -77,6 +81,7 @@ class BaseModel(nn.Layer):
if self.use_neck: if self.use_neck:
x = self.neck(x) x = self.neck(x)
y["neck_out"] = x y["neck_out"] = x
if self.use_head:
x = self.head(x, targets=data) x = self.head(x, targets=data)
if isinstance(x, dict): if isinstance(x, dict):
y.update(x) y.update(x)
......
...@@ -29,9 +29,10 @@ def build_backbone(config, model_type): ...@@ -29,9 +29,10 @@ def build_backbone(config, model_type):
from .rec_nrtr_mtb import MTB from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31 from .rec_resnet_31 import ResNet31
from .rec_resnet_aster import ResNet_ASTER from .rec_resnet_aster import ResNet_ASTER
from .rec_micronet import MicroNet
support_dict = [ support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
"ResNet31", "ResNet_ASTER" "ResNet31", "ResNet_ASTER", 'MicroNet'
] ]
elif model_type == "e2e": elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet from .e2e_resnet_vd_pg import ResNet
...@@ -43,6 +44,9 @@ def build_backbone(config, model_type): ...@@ -43,6 +44,9 @@ def build_backbone(config, model_type):
from .table_resnet_vd import ResNet from .table_resnet_vd import ResNet
from .table_mobilenet_v3 import MobileNetV3 from .table_mobilenet_v3 import MobileNetV3
support_dict = ["ResNet", "MobileNetV3"] support_dict = ["ResNet", "MobileNetV3"]
elif model_type == 'vqa':
from .vqa_layoutlm import LayoutLMForSer, LayoutXLMForSer, LayoutXLMForRe
support_dict = ["LayoutLMForSer", "LayoutXLMForSer", 'LayoutXLMForRe']
else: else:
raise NotImplementedError raise NotImplementedError
......
This diff is collapsed.
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from paddle import nn
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
# Public model classes of this module. LayoutXLMForRe is listed as well
# because the backbone builder imports it from here.
__all__ = ["LayoutXLMForSer", 'LayoutLMForSer', 'LayoutXLMForRe']

# Maps a base model class to the pretrained weight name used for it.
pretrained_model_dict = {
    LayoutXLMModel: 'layoutxlm-base-uncased',
    LayoutLMModel: 'layoutlm-base-uncased'
}
class NLPBaseModel(nn.Layer):
    """Shared wrapper that builds a paddlenlp model and stores it as ``self.model``.

    Args:
        base_model_class: backbone class; also the key used to look up the
            pretrained weight name in ``pretrained_model_dict``.
        model_class: task model class built on top of the backbone.
        type (str): ``'ser'`` passes ``num_classes`` to ``model_class``;
            any other value builds ``model_class`` without it.
        pretrained (bool): load pretrained backbone weights when True,
            otherwise build the backbone from its stored init configuration.
        checkpoints: if not None, ``model_class.from_pretrained`` is called
            with it and the other construction options are not used.
        **kwargs: must contain ``num_classes`` when ``type == 'ser'`` and no
            checkpoint is given.
    """

    def __init__(self,
                 base_model_class,
                 model_class,
                 type='ser',
                 pretrained=True,
                 checkpoints=None,
                 **kwargs):
        super(NLPBaseModel, self).__init__()
        if checkpoints is None:
            weight_name = pretrained_model_dict[base_model_class]
            if pretrained:
                backbone = base_model_class.from_pretrained(weight_name)
            else:
                init_cfg = base_model_class.pretrained_init_configuration[
                    weight_name]
                backbone = base_model_class(**init_cfg)

            if type == 'ser':
                self.model = model_class(
                    backbone, num_classes=kwargs['num_classes'], dropout=None)
            else:
                self.model = model_class(backbone, dropout=None)
        else:
            self.model = model_class.from_pretrained(checkpoints)
        self.out_channels = 1
class LayoutXLMForSer(NLPBaseModel):
    """LayoutXLM token-classification model built through ``NLPBaseModel``.

    Args:
        num_classes: forwarded to the token-classification model.
        pretrained (bool): forwarded to ``NLPBaseModel``.
        checkpoints: forwarded to ``NLPBaseModel``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, num_classes, pretrained=True, checkpoints=None,
                 **kwargs):
        super(LayoutXLMForSer, self).__init__(
            LayoutXLMModel,
            LayoutXLMForTokenClassification,
            'ser',
            pretrained,
            checkpoints,
            num_classes=num_classes)

    def forward(self, x):
        """Run the model on ``x`` and return the first element of its output.

        ``x`` is indexed positionally: 0 -> input_ids, 2 -> bbox, 3 -> image,
        4 -> attention_mask, 5 -> token_type_ids.
        """
        outputs = self.model(
            input_ids=x[0],
            bbox=x[2],
            image=x[3],
            attention_mask=x[4],
            token_type_ids=x[5],
            position_ids=None,
            head_mask=None,
            labels=None)
        return outputs[0]
class LayoutLMForSer(NLPBaseModel):
    """LayoutLM token-classification model built through ``NLPBaseModel``.

    Args:
        num_classes: forwarded to the token-classification model.
        pretrained (bool): forwarded to ``NLPBaseModel``.
        checkpoints: forwarded to ``NLPBaseModel``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, num_classes, pretrained=True, checkpoints=None,
                 **kwargs):
        super(LayoutLMForSer, self).__init__(
            LayoutLMModel,
            LayoutLMForTokenClassification,
            'ser',
            pretrained,
            checkpoints,
            num_classes=num_classes)

    def forward(self, x):
        """Run the model on ``x`` and return its output unchanged.

        ``x`` is indexed positionally: 0 -> input_ids, 2 -> bbox,
        4 -> attention_mask, 5 -> token_type_ids.
        """
        outputs = self.model(
            input_ids=x[0],
            bbox=x[2],
            attention_mask=x[4],
            token_type_ids=x[5],
            position_ids=None,
            output_hidden_states=False)
        return outputs
class LayoutXLMForRe(NLPBaseModel):
    """LayoutXLM relation-extraction model built through ``NLPBaseModel``.

    Args:
        pretrained (bool): forwarded to ``NLPBaseModel``.
        checkpoints: forwarded to ``NLPBaseModel``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, pretrained=True, checkpoints=None, **kwargs):
        super(LayoutXLMForRe, self).__init__(
            LayoutXLMModel,
            LayoutXLMForRelationExtraction,
            're',
            pretrained,
            checkpoints)

    def forward(self, x):
        """Run the model on ``x`` and return its output unchanged.

        ``x`` is indexed positionally: 0 -> input_ids, 1 -> bbox, 2 -> image,
        3 -> attention_mask, 4 -> token_type_ids, 5 -> entities,
        6 -> relations.
        """
        outputs = self.model(
            input_ids=x[0],
            bbox=x[1],
            labels=None,
            image=x[2],
            attention_mask=x[3],
            token_type_ids=x[4],
            position_ids=None,
            head_mask=None,
            entities=x[5],
            relations=x[6])
        return outputs
...@@ -42,7 +42,9 @@ def build_optimizer(config, epochs, step_each_epoch, parameters): ...@@ -42,7 +42,9 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
# step2 build regularization # step2 build regularization
if 'regularizer' in config and config['regularizer'] is not None: if 'regularizer' in config and config['regularizer'] is not None:
reg_config = config.pop('regularizer') reg_config = config.pop('regularizer')
reg_name = reg_config.pop('name') + 'Decay' reg_name = reg_config.pop('name')
if not hasattr(regularizer, reg_name):
reg_name += 'Decay'
reg = getattr(regularizer, reg_name)(**reg_config)() reg = getattr(regularizer, reg_name)(**reg_config)()
else: else:
reg = None reg = None
......
...@@ -18,7 +18,7 @@ from __future__ import print_function ...@@ -18,7 +18,7 @@ from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
from paddle.optimizer import lr from paddle.optimizer import lr
from .lr_scheduler import CyclicalCosineDecay from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay
class Linear(object): class Linear(object):
...@@ -226,3 +226,53 @@ class CyclicalCosine(object): ...@@ -226,3 +226,53 @@ class CyclicalCosine(object):
end_lr=self.learning_rate, end_lr=self.learning_rate,
last_epoch=self.last_epoch) last_epoch=self.last_epoch)
return learning_rate return learning_rate
class OneCycle(object):
    """
    One Cycle learning rate decay.

    Calling the instance builds a ``OneCycleDecay`` scheduler and, when a
    warm-up is requested, wraps it in ``lr.LinearWarmup``.

    Args:
        max_lr(float): upper learning rate boundary.
        epochs(int): total training epochs.
        step_each_epoch(int): steps in each epoch.
        anneal_strategy(str): 'cos' for cosine annealing or 'linear' for
            linear annealing. Default: 'cos'.
        three_phase(bool): forwarded to ``OneCycleDecay``; selects the
            three-phase schedule instead of the two-phase one. Default: False.
        warmup_epoch(int|float): warm-up length in epochs. It is converted to
            a step count on construction; 0 disables the warm-up. Default: 0.
        last_epoch(int, optional): index of the last epoch, forwarded to the
            schedulers. Default: -1.
    """

    def __init__(self,
                 max_lr,
                 epochs,
                 step_each_epoch,
                 anneal_strategy='cos',
                 three_phase=False,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        super(OneCycle, self).__init__()
        self.max_lr = max_lr
        self.epochs = epochs
        self.steps_per_epoch = step_each_epoch
        self.anneal_strategy = anneal_strategy
        self.three_phase = three_phase
        self.last_epoch = last_epoch
        # Despite its name this attribute holds warm-up *steps*.
        warmup_steps = warmup_epoch * step_each_epoch
        self.warmup_epoch = round(warmup_steps)

    def __call__(self):
        scheduler = OneCycleDecay(
            max_lr=self.max_lr,
            epochs=self.epochs,
            steps_per_epoch=self.steps_per_epoch,
            anneal_strategy=self.anneal_strategy,
            three_phase=self.three_phase,
            last_epoch=self.last_epoch)
        if not self.warmup_epoch > 0:
            return scheduler
        return lr.LinearWarmup(
            learning_rate=scheduler,
            warmup_steps=self.warmup_epoch,
            start_lr=0.0,
            end_lr=self.max_lr,
            last_epoch=self.last_epoch)
\ No newline at end of file
...@@ -47,3 +47,116 @@ class CyclicalCosineDecay(LRScheduler): ...@@ -47,3 +47,116 @@ class CyclicalCosineDecay(LRScheduler):
lr = self.eta_min + 0.5 * (self.base_lr - self.eta_min) * \ lr = self.eta_min + 0.5 * (self.base_lr - self.eta_min) * \
(1 + math.cos(math.pi * reletive_epoch / self.cycle)) (1 + math.cos(math.pi * reletive_epoch / self.cycle))
return lr return lr
class OneCycleDecay(LRScheduler):
    """
    One Cycle learning rate decay.

    The schedule is described in https://arxiv.org/abs/1708.07120.
    Code adapted from https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR

    The learning rate starts at ``max_lr / div_factor``, is annealed up to
    ``max_lr`` during the first ``pct_start`` fraction of the steps and is
    then annealed down to ``max_lr / div_factor / final_div_factor``.

    Args:
        max_lr (float): upper learning rate boundary.
        epochs (int): total number of epochs; must be a positive int.
        steps_per_epoch (int): steps per epoch; must be a positive int.
        pct_start (float): fraction of the steps spent increasing the
            learning rate; must be a float in [0, 1]. Default: 0.3.
        anneal_strategy (str): 'cos' or 'linear'. Default: 'cos'.
        div_factor (float): ``initial_lr = max_lr / div_factor``.
        final_div_factor (float): ``min_lr = initial_lr / final_div_factor``.
        three_phase (bool): if True, use three phases (up, symmetric down to
            the initial rate, then down to the minimum rate) instead of two.
        last_epoch (int): forwarded to ``LRScheduler``. Default: -1.
        verbose (bool): forwarded to ``LRScheduler``. Default: False.

    Raises:
        ValueError: if ``epochs``, ``steps_per_epoch``, ``pct_start`` or
            ``anneal_strategy`` is invalid.
    """

    def __init__(self,
                 max_lr,
                 epochs=None,
                 steps_per_epoch=None,
                 pct_start=0.3,
                 anneal_strategy='cos',
                 div_factor=25.,
                 final_div_factor=1e4,
                 three_phase=False,
                 last_epoch=-1,
                 verbose=False):
        # Validate total_steps. The type is checked before the comparison so
        # that the default ``None`` produces the intended ValueError rather
        # than a TypeError from ``None <= 0``.
        if not isinstance(epochs, int) or epochs <= 0:
            raise ValueError(
                "Expected positive integer epochs, but got {}".format(epochs))
        if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0:
            raise ValueError(
                "Expected positive integer steps_per_epoch, but got {}".format(
                    steps_per_epoch))

        # Validate pct_start before it is used to build the phase table
        if not isinstance(pct_start, float) or pct_start < 0 or pct_start > 1:
            raise ValueError(
                "Expected float between 0 and 1 pct_start, but got {}".format(
                    pct_start))

        # Validate anneal_strategy
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError(
                "anneal_strategy must by one of 'cos' or 'linear', instead got {}".
                format(anneal_strategy))
        elif anneal_strategy == 'cos':
            self.anneal_func = self._annealing_cos
        elif anneal_strategy == 'linear':
            self.anneal_func = self._annealing_linear

        self.total_steps = epochs * steps_per_epoch

        self.max_lr = max_lr
        self.initial_lr = self.max_lr / div_factor
        self.min_lr = self.initial_lr / final_div_factor

        if three_phase:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': self.initial_lr,
                    'end_lr': self.max_lr,
                },
                {
                    'end_step': float(2 * pct_start * self.total_steps) - 2,
                    'start_lr': self.max_lr,
                    'end_lr': self.initial_lr,
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': self.initial_lr,
                    'end_lr': self.min_lr,
                },
            ]
        else:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': self.initial_lr,
                    'end_lr': self.max_lr,
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': self.max_lr,
                    'end_lr': self.min_lr,
                },
            ]

        # Base-class initialisation stays last, after every attribute read by
        # get_lr() has been set.
        # NOTE(review): presumably LRScheduler.__init__ triggers get_lr() - confirm.
        super(OneCycleDecay, self).__init__(max_lr, last_epoch, verbose)

    def _annealing_cos(self, start, end, pct):
        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    def _annealing_linear(self, start, end, pct):
        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        return (end - start) * pct + start

    def get_lr(self):
        """Return the learning rate for the step stored in ``self.last_epoch``.

        Raises:
            ValueError: if the step index exceeds ``total_steps``.
        """
        computed_lr = 0.0
        step_num = self.last_epoch

        if step_num > self.total_steps:
            raise ValueError(
                "Tried to step {} times. The specified number of total steps is {}"
                .format(step_num + 1, self.total_steps))

        start_step = 0
        last_index = len(self._schedule_phases) - 1
        for i, phase in enumerate(self._schedule_phases):
            end_step = phase['end_step']
            # The final phase also handles any step past its nominal end.
            if step_num <= end_step or i == last_index:
                pct = (step_num - start_step) / (end_step - start_step)
                computed_lr = self.anneal_func(phase['start_lr'],
                                               phase['end_lr'], pct)
                break
            start_step = phase['end_step']

        return computed_lr
...@@ -158,3 +158,38 @@ class Adadelta(object): ...@@ -158,3 +158,38 @@ class Adadelta(object):
name=self.name, name=self.name,
parameters=parameters) parameters=parameters)
return opt return opt
class AdamW(object):
    """Configuration holder that builds an ``optim.AdamW`` optimizer on call.

    Args:
        learning_rate: forwarded to ``optim.AdamW``. Default: 0.001.
        beta1 (float): forwarded to ``optim.AdamW``. Default: 0.9.
        beta2 (float): forwarded to ``optim.AdamW``. Default: 0.999.
        epsilon (float): forwarded to ``optim.AdamW``. Default: 1e-08.
        weight_decay: forwarded to ``optim.AdamW``; ``None`` is replaced by
            0.01. Default: 0.01.
        grad_clip: forwarded to ``optim.AdamW``. Default: None.
        name: forwarded to ``optim.AdamW``. Default: None.
        lazy_mode (bool): forwarded to ``optim.AdamW``. Default: False.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 learning_rate=0.001,
                 beta1=0.9,
                 beta2=0.999,
                 epsilon=1e-08,
                 weight_decay=0.01,
                 grad_clip=None,
                 name=None,
                 lazy_mode=False,
                 **kwargs):
        self.learning_rate = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        # An explicit None (e.g. from a config file) falls back to the default.
        self.weight_decay = 0.01 if weight_decay is None else weight_decay
        self.grad_clip = grad_clip
        self.name = name
        self.lazy_mode = lazy_mode

    def __call__(self, parameters):
        """Build and return the optimizer for ``parameters``."""
        opt = optim.AdamW(
            learning_rate=self.learning_rate,
            beta1=self.beta1,
            beta2=self.beta2,
            epsilon=self.epsilon,
            weight_decay=self.weight_decay,
            grad_clip=self.grad_clip,
            name=self.name,
            lazy_mode=self.lazy_mode,
            parameters=parameters)
        return opt
...@@ -29,24 +29,23 @@ class L1Decay(object): ...@@ -29,24 +29,23 @@ class L1Decay(object):
def __init__(self, factor=0.0): def __init__(self, factor=0.0):
super(L1Decay, self).__init__() super(L1Decay, self).__init__()
self.regularization_coeff = factor self.coeff = factor
def __call__(self): def __call__(self):
reg = paddle.regularizer.L1Decay(self.regularization_coeff) reg = paddle.regularizer.L1Decay(self.coeff)
return reg return reg
class L2Decay(object): class L2Decay(object):
""" """
L2 Weight Decay Regularization, which encourages the weights to be sparse. L2 Weight Decay Regularization, which helps to prevent the model over-fitting.
Args: Args:
factor(float): regularization coeff. Default:0.0. factor(float): regularization coeff. Default:0.0.
""" """
def __init__(self, factor=0.0): def __init__(self, factor=0.0):
super(L2Decay, self).__init__() super(L2Decay, self).__init__()
self.regularization_coeff = factor self.coeff = float(factor)
def __call__(self): def __call__(self):
reg = paddle.regularizer.L2Decay(self.regularization_coeff) return self.coeff
return reg \ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment