Commit 799a38c5 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #616 failed with stages
in 0 seconds
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
import math
from dataclasses import dataclass, field
from typing import Optional
import torch
import torch.nn.functional as F
import numpy as np
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II
@dataclass
class AdjustLabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
label_smoothing: float = field(
default=0.0,
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
)
report_accuracy: bool = field(
default=False,
metadata={"help": "report accuracy metric"},
)
ignore_prefix_size: int = field(
default=0,
metadata={"help": "Ignore first N tokens"},
)
ignore_eos: bool = field(
default=False,
metadata={"help": "Ignore eos token"},
)
sentence_avg: bool = II("optimization.sentence_avg")
drop_worst_ratio: float = field(
default=0.0,
metadata={"help": "ratio for discarding bad samples"},
)
drop_worst_after: int = field(
default=0,
metadata={"help": "steps for discarding bad samples"},
)
use_rdrop: bool = field(
default=False, metadata={"help": "use R-Drop"}
)
reg_alpha: float = field(
default=1.0, metadata={"help": "weight for R-Drop"}
)
sample_patch_num: int = field(
default=196, metadata={"help": "sample patches for v1"}
)
constraint_range: Optional[str] = field(
default=None,
metadata={"help": "constraint range"}
)
def construct_rdrop_sample(x):
if isinstance(x, dict):
for key in x:
x[key] = construct_rdrop_sample(x[key])
return x
elif isinstance(x, torch.Tensor):
return x.repeat(2, *([1] * (x.dim()-1)))
elif isinstance(x, int):
return x * 2
elif isinstance(x, np.ndarray):
return x.repeat(2)
else:
raise NotImplementedError
def kl_loss(p, q):
p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
loss = (p_loss + q_loss) / 2
return loss
def label_smoothed_nll_loss(
lprobs, target, epsilon, update_num, reduce=True,
drop_worst_ratio=0.0, drop_worst_after=0, use_rdrop=False, reg_alpha=1.0,
constraint_masks=None, constraint_start=None, constraint_end=None
):
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(-1)
nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
if constraint_masks is not None:
smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(dim=-1, keepdim=True).squeeze(-1)
eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
elif constraint_start is not None and constraint_end is not None:
constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
smooth_loss = -lprobs[:, constraint_range].sum(dim=-1, keepdim=True).squeeze(-1)
eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
else:
smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
eps_i = epsilon / (lprobs.size(-1) - 1)
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
if drop_worst_ratio > 0 and update_num > drop_worst_after:
if use_rdrop:
true_batch_size = loss.size(0) // 2
_, indices = torch.topk(loss[:true_batch_size], k=int(true_batch_size * (1 - drop_worst_ratio)), largest=False)
loss = torch.cat([loss[indices], loss[indices+true_batch_size]])
nll_loss = torch.cat([nll_loss[indices], nll_loss[indices+true_batch_size]])
lprobs = torch.cat([lprobs[indices], lprobs[indices+true_batch_size]])
else:
loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_worst_ratio)), largest=False)
nll_loss = nll_loss[indices]
lprobs = lprobs[indices]
ntokens = loss.numel()
nll_loss = nll_loss.sum()
loss = loss.sum()
if use_rdrop:
true_batch_size = lprobs.size(0) // 2
p = lprobs[:true_batch_size]
q = lprobs[true_batch_size:]
if constraint_start is not None and constraint_end is not None:
constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
p = p[:, constraint_range]
q = q[:, constraint_range]
loss += kl_loss(p, q) * reg_alpha
return loss, nll_loss, ntokens
@register_criterion(
"adjust_label_smoothed_cross_entropy", dataclass=AdjustLabelSmoothedCrossEntropyCriterionConfig
)
class AdjustLabelSmoothedCrossEntropyCriterion(FairseqCriterion):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size=0,
ignore_eos=False,
report_accuracy=False,
drop_worst_ratio=0,
drop_worst_after=0,
use_rdrop=False,
reg_alpha=1.0,
sample_patch_num=196,
constraint_range=None
):
super().__init__(task)
self.sentence_avg = sentence_avg
self.eps = label_smoothing
self.ignore_prefix_size = ignore_prefix_size
self.ignore_eos = ignore_eos
self.report_accuracy = report_accuracy
self.drop_worst_ratio = drop_worst_ratio
self.drop_worst_after = drop_worst_after
self.use_rdrop = use_rdrop
self.reg_alpha = reg_alpha
self.sample_patch_num = sample_patch_num
self.constraint_start = None
self.constraint_end = None
if constraint_range is not None:
constraint_start, constraint_end = constraint_range.split(',')
self.constraint_start = int(constraint_start)
self.constraint_end = int(constraint_end)
def forward(self, model, sample, update_num=0, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
if isinstance(sample, list):
if self.sample_patch_num > 0:
sample[0]['net_input']['sample_patch_num'] = self.sample_patch_num
loss_v1, sample_size_v1, logging_output_v1 = self.forward(model, sample[0], update_num, reduce)
loss_v2, sample_size_v2, logging_output_v2 = self.forward(model, sample[1], update_num, reduce)
loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
sample_size = 1
logging_output = {
"loss": loss.data,
"loss_v1": loss_v1.data,
"loss_v2": loss_v2.data,
"nll_loss": logging_output_v1["nll_loss"].data / sample_size_v1 + logging_output_v2["nll_loss"].data / sample_size_v2,
"ntokens": logging_output_v1["ntokens"] + logging_output_v2["ntokens"],
"nsentences": logging_output_v1["nsentences"] + logging_output_v2["nsentences"],
"sample_size": 1,
"sample_size_v1": sample_size_v1,
"sample_size_v2": sample_size_v2,
}
return loss, sample_size, logging_output
if self.use_rdrop:
construct_rdrop_sample(sample)
net_output = model(**sample["net_input"])
loss, nll_loss, ntokens = self.compute_loss(model, net_output, sample, update_num, reduce=reduce)
sample_size = (
sample["target"].size(0) if self.sentence_avg else ntokens
)
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output, sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
return loss, sample_size, logging_output
def get_lprobs_and_target(self, model, net_output, sample):
conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
constraint_masks = None
if "constraint_masks" in sample and sample["constraint_masks"] is not None:
constraint_masks = sample["constraint_masks"]
net_output[0] = net_output[0].masked_fill(~constraint_masks, -math.inf)
if self.constraint_start is not None and self.constraint_end is not None:
net_output[0][:, :, 4:self.constraint_start] = -math.inf
net_output[0][:, :, self.constraint_end:] = -math.inf
lprobs = model.get_normalized_probs(net_output, log_probs=True) * conf
target = model.get_targets(sample, net_output)
if self.ignore_prefix_size > 0:
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
target = target[:, self.ignore_prefix_size :].contiguous()
if constraint_masks is not None:
constraint_masks = constraint_masks[:, self.ignore_prefix_size :, :].contiguous()
if self.ignore_eos:
bsz, seq_len, embed_dim = lprobs.size()
eos_indices = target.eq(self.task.tgt_dict.eos())
lprobs = lprobs[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
target = target[~eos_indices].reshape(bsz, seq_len-1)
if constraint_masks is not None:
constraint_masks = constraint_masks[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
if constraint_masks is not None:
constraint_masks = constraint_masks.view(-1, constraint_masks.size(-1))
return lprobs.view(-1, lprobs.size(-1)), target.view(-1), constraint_masks
def compute_loss(self, model, net_output, sample, update_num, reduce=True):
lprobs, target, constraint_masks = self.get_lprobs_and_target(model, net_output, sample)
if constraint_masks is not None:
constraint_masks = constraint_masks[target != self.padding_idx]
lprobs = lprobs[target != self.padding_idx]
target = target[target != self.padding_idx]
loss, nll_loss, ntokens = label_smoothed_nll_loss(
lprobs,
target,
self.eps,
update_num,
reduce=reduce,
drop_worst_ratio=self.drop_worst_ratio,
drop_worst_after=self.drop_worst_after,
use_rdrop=self.use_rdrop,
reg_alpha=self.reg_alpha,
constraint_masks=constraint_masks,
constraint_start=self.constraint_start,
constraint_end=self.constraint_end
)
return loss, nll_loss, ntokens
def compute_accuracy(self, model, net_output, sample):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
mask = target.ne(self.padding_idx)
n_correct = torch.sum(
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
)
total = torch.sum(mask)
return n_correct, total
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
loss_sum_v1 = sum(log.get("loss_v1", 0) for log in logging_outputs)
loss_sum_v2 = sum(log.get("loss_v2", 0) for log in logging_outputs)
nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
sample_size_v1 = sum(log.get("sample_size_v1", 0) for log in logging_outputs)
sample_size_v2 = sum(log.get("sample_size_v2", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size, sample_size, round=3
)
metrics.log_scalar(
"loss_v1", loss_sum_v1 / max(sample_size_v1, 1), max(sample_size_v1, 1), round=3
)
metrics.log_scalar(
"loss_v2", loss_sum_v2 / max(sample_size_v2, 1), max(sample_size_v2, 1), round=3
)
metrics.log_scalar(
"nll_loss", nll_loss_sum / sample_size, ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
metrics.log_scalar(
"ntokens", ntokens, 1, round=3
)
metrics.log_scalar(
"nsentences", nsentences, 1, round=3
)
metrics.log_scalar(
"sample_size", sample_size, 1, round=3
)
metrics.log_scalar(
"sample_size_v1", sample_size_v1, 1, round=3
)
metrics.log_scalar(
"sample_size_v2", sample_size_v2, 1, round=3
)
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
if total > 0:
metrics.log_scalar("total", total)
n_correct = utils.item(
sum(log.get("n_correct", 0) for log in logging_outputs)
)
metrics.log_scalar("n_correct", n_correct)
metrics.log_derived(
"accuracy",
lambda meters: round(
meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
)
if meters["total"].sum > 0
else float("nan"),
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
from typing import Optional
import torch
import torch.nn.functional as F
import numpy as np
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II
@dataclass
class AdjustLabelSmoothedEncouragingLossConfig(FairseqDataclass):
label_smoothing: float = field(
default=0.0,
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
)
report_accuracy: bool = field(
default=False,
metadata={"help": "report accuracy metric"},
)
ignore_prefix_size: int = field(
default=0,
metadata={"help": "Ignore first N tokens"},
)
ignore_eos: bool = field(
default=False,
metadata={"help": "Ignore eos token"},
)
sentence_avg: bool = II("optimization.sentence_avg")
drop_worst_ratio: float = field(
default=0.0,
metadata={"help": "ratio for discarding bad samples"},
)
drop_worst_after: int = field(
default=0,
metadata={"help": "steps for discarding bad samples"},
)
use_rdrop: bool = field(
default=False, metadata={"help": "use R-Drop"}
)
reg_alpha: float = field(
default=1.0, metadata={"help": "weight for R-Drop"}
)
sample_patch_num: int = field(
default=196, metadata={"help": "sample patchs for v1"}
)
constraint_range: Optional[str] = field(
default=None,
metadata={"help": "constraint range"}
)
log_end: float = field(
default=0.75,
metadata={"help": "higher log_end is for cases with higher performance,"
" we recommend 0.75 or 0.5 as your first try."}
)
drop_best_ratio: float = field(
default=0.0,
metadata={"help": "ratio for discarding best samples"},
)
drop_best_after: int = field(
default=0,
metadata={"help": "steps for discarding best samples"},
)
def construct_rdrop_sample(x):
if isinstance(x, dict):
for key in x:
x[key] = construct_rdrop_sample(x[key])
return x
elif isinstance(x, torch.Tensor):
return x.repeat(2, *([1] * (x.dim()-1)))
elif isinstance(x, int):
return x * 2
elif isinstance(x, np.ndarray):
return x.repeat(2)
else:
raise NotImplementedError
def kl_loss(p, q):
p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
loss = (p_loss + q_loss) / 2
return loss
def label_smoothed_nll_loss(
lprobs, target, epsilon, update_num, reduce=True,
drop_worst_ratio=0.0, drop_worst_after=0, use_rdrop=False, reg_alpha=1.0,
constraint_masks=None, constraint_start=None, constraint_end=None, drop_best_ratio=0.0,
drop_best_after=0,
):
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(-1)
nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
if constraint_masks is not None:
smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(dim=-1, keepdim=True).squeeze(-1)
eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
elif constraint_start is not None and constraint_end is not None:
constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
smooth_loss = -lprobs[:, constraint_range].sum(dim=-1, keepdim=True).squeeze(-1)
eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
else:
smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
eps_i = epsilon / (lprobs.size(-1) - 1)
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
if drop_worst_ratio > 0 and update_num > drop_worst_after:
if use_rdrop:
true_batch_size = loss.size(0) // 2
_, indices = torch.topk(loss[:true_batch_size], k=int(true_batch_size * (1 - drop_worst_ratio)), largest=False)
loss = torch.cat([loss[indices], loss[indices+true_batch_size]])
nll_loss = torch.cat([nll_loss[indices], nll_loss[indices+true_batch_size]])
lprobs = torch.cat([lprobs[indices], lprobs[indices+true_batch_size]])
else:
loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_worst_ratio)), largest=False)
nll_loss = nll_loss[indices]
lprobs = lprobs[indices]
target = target[indices]
if update_num > drop_best_after:
loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_best_ratio)), largest=True)
nll_loss = nll_loss[indices]
lprobs = lprobs[indices]
target = target[indices]
ntokens = loss.numel()
nll_loss = nll_loss.sum()
loss = loss.sum()
if use_rdrop:
true_batch_size = lprobs.size(0) // 2
p = lprobs[:true_batch_size]
q = lprobs[true_batch_size:]
if constraint_start is not None and constraint_end is not None:
constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
p = p[:, constraint_range]
q = q[:, constraint_range]
loss += kl_loss(p, q) * reg_alpha
return loss, nll_loss, ntokens,lprobs,target
@register_criterion(
"adjust_label_smoothed_encouraging_loss", dataclass=AdjustLabelSmoothedEncouragingLossConfig
)
class AdjustLabelSmoothedEncouragingLossCriterion(FairseqCriterion):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size=0,
ignore_eos=False,
report_accuracy=False,
drop_worst_ratio=0,
drop_worst_after=0,
use_rdrop=False,
reg_alpha=1.0,
sample_patch_num=196,
constraint_range=None,
log_end=0.75,
drop_best_ratio=0.0,
drop_best_after=0,
):
super().__init__(task)
self.sentence_avg = sentence_avg
self.eps = label_smoothing
self.ignore_prefix_size = ignore_prefix_size
self.ignore_eos = ignore_eos
self.report_accuracy = report_accuracy
self.drop_worst_ratio = drop_worst_ratio
self.drop_worst_after = drop_worst_after
self.use_rdrop = use_rdrop
self.reg_alpha = reg_alpha
self.sample_patch_num = sample_patch_num
self.constraint_start = None
self.constraint_end = None
if constraint_range is not None:
constraint_start, constraint_end = constraint_range.split(',')
self.constraint_start = int(constraint_start)
self.constraint_end = int(constraint_end)
self.log_end = log_end
self.drop_best_ratio = drop_best_ratio
self.drop_best_after = drop_best_after
print('el, self.log_end=', self.log_end)
# @staticmethod
# def add_args(parser):
# """Add criterion-specific arguments to the parser."""
# # fmt: off
# parser.add_argument('--log_end', type=float, default=1.0)
def forward(self, model, sample, update_num=0, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
if isinstance(sample, list):
if self.sample_patch_num > 0:
sample[0]['net_input']['sample_patch_num'] = self.sample_patch_num
loss_v1, sample_size_v1, logging_output_v1 = self.forward(model, sample[0], update_num, reduce)
loss_v2, sample_size_v2, logging_output_v2 = self.forward(model, sample[1], update_num, reduce)
loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
sample_size = 1
logging_output = {
"loss": loss.data,
"loss_v1": loss_v1.data,
"loss_v2": loss_v2.data,
"nll_loss": logging_output_v1["nll_loss"].data / sample_size_v1 + logging_output_v2["nll_loss"].data / sample_size_v2,
"ntokens": logging_output_v1["ntokens"] + logging_output_v2["ntokens"],
"nsentences": logging_output_v1["nsentences"] + logging_output_v2["nsentences"],
"sample_size": 1,
"sample_size_v1": sample_size_v1,
"sample_size_v2": sample_size_v2,
}
return loss, sample_size, logging_output
if self.use_rdrop:
construct_rdrop_sample(sample)
net_output = model(**sample["net_input"])
loss, nll_loss, ntokens = self.compute_loss(model, net_output, sample, update_num, reduce=reduce)
sample_size = (
sample["target"].size(0) if self.sentence_avg else ntokens
)
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output, sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
return loss, sample_size, logging_output
def get_lprobs_and_target(self, model, net_output, sample):
conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
constraint_masks = None
if "constraint_masks" in sample and sample["constraint_masks"] is not None:
constraint_masks = sample["constraint_masks"]
net_output[0].masked_fill_(~constraint_masks, -math.inf)
if self.constraint_start is not None and self.constraint_end is not None:
net_output[0][:, :, 4:self.constraint_start] = -math.inf
net_output[0][:, :, self.constraint_end:] = -math.inf
lprobs = model.get_normalized_probs(net_output, log_probs=True) * conf
target = model.get_targets(sample, net_output)
if self.ignore_prefix_size > 0:
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
target = target[:, self.ignore_prefix_size :].contiguous()
if constraint_masks is not None:
constraint_masks = constraint_masks[:, self.ignore_prefix_size :, :].contiguous()
if self.ignore_eos:
bsz, seq_len, embed_dim = lprobs.size()
eos_indices = target.eq(self.task.tgt_dict.eos())
lprobs = lprobs[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
target = target[~eos_indices].reshape(bsz, seq_len-1)
if constraint_masks is not None:
constraint_masks = constraint_masks[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
if constraint_masks is not None:
constraint_masks = constraint_masks.view(-1, constraint_masks.size(-1))
return lprobs.view(-1, lprobs.size(-1)), target.view(-1), constraint_masks
def compute_loss(self, model, net_output, sample, update_num, reduce=True):
lprobs, target, constraint_masks = self.get_lprobs_and_target(model, net_output, sample)
if constraint_masks is not None:
constraint_masks = constraint_masks[target != self.padding_idx]
lprobs = lprobs[target != self.padding_idx]
target = target[target != self.padding_idx]
loss, nll_loss, ntokens, lprobs, target = label_smoothed_nll_loss(
lprobs,
target,
self.eps,
update_num,
reduce=reduce,
drop_worst_ratio=self.drop_worst_ratio,
drop_worst_after=self.drop_worst_after,
use_rdrop=self.use_rdrop,
reg_alpha=self.reg_alpha,
constraint_masks=constraint_masks,
constraint_start=self.constraint_start,
constraint_end=self.constraint_end
)
# for encouraging loss
probs = torch.exp(lprobs)
bonus = torch.log(torch.clamp((torch.ones_like(probs) - probs), min=1e-5)) # likelihood bonus
log_end = self.log_end
if log_end != 1.0: # e.g. 0.9
y_log_end = torch.log(torch.ones_like(probs) - log_end)
bonus_after_log_end = 1 / (log_end - torch.ones_like(probs)) * (probs - log_end) + y_log_end
# x:log_end, y torch.log(torch.clamp((torch.ones_like(probs) - probs), min=self.cl_eps))
bonus = torch.where(probs > log_end, bonus_after_log_end, bonus)
c_loss = F.nll_loss(
-bonus,
target.view(-1),
reduction='sum',
)
smoothing_c_loss = bonus.sum(dim=-1)
smoothing_c_loss = smoothing_c_loss.sum()
c_loss = c_loss * (1 - self.eps) + (self.eps / lprobs.size(-1)) * smoothing_c_loss
loss = loss + c_loss
# end for encouraging loss
return loss, nll_loss, ntokens
def compute_accuracy(self, model, net_output, sample):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
mask = target.ne(self.padding_idx)
n_correct = torch.sum(
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
)
total = torch.sum(mask)
return n_correct, total
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
loss_sum_v1 = sum(log.get("loss_v1", 0) for log in logging_outputs)
loss_sum_v2 = sum(log.get("loss_v2", 0) for log in logging_outputs)
nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
sample_size_v1 = sum(log.get("sample_size_v1", 0) for log in logging_outputs)
sample_size_v2 = sum(log.get("sample_size_v2", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size, sample_size, round=3
)
metrics.log_scalar(
"loss_v1", loss_sum_v1 / max(sample_size_v1, 1), max(sample_size_v1, 1), round=3
)
metrics.log_scalar(
"loss_v2", loss_sum_v2 / max(sample_size_v2, 1), max(sample_size_v2, 1), round=3
)
metrics.log_scalar(
"nll_loss", nll_loss_sum / sample_size, ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
metrics.log_scalar(
"ntokens", ntokens, 1, round=3
)
metrics.log_scalar(
"nsentences", nsentences, 1, round=3
)
metrics.log_scalar(
"sample_size", sample_size, 1, round=3
)
metrics.log_scalar(
"sample_size_v1", sample_size_v1, 1, round=3
)
metrics.log_scalar(
"sample_size_v2", sample_size_v2, 1, round=3
)
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
if total > 0:
metrics.log_scalar("total", total)
n_correct = utils.item(
sum(log.get("n_correct", 0) for log in logging_outputs)
)
metrics.log_scalar("n_correct", n_correct)
metrics.log_derived(
"accuracy",
lambda meters: round(
meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
)
if meters["total"].sum > 0
else float("nan"),
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
import math
import string
from dataclasses import dataclass, field
from collections import OrderedDict
from typing import Optional
import torch
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II
from data import data_utils
from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD
def scst_loss(lprobs, target, reward, ignore_index=None, reduce=True):
loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze() * reward.unsqueeze(-1)
if ignore_index is not None:
pad_mask = target.eq(ignore_index)
loss.masked_fill_(pad_mask, 0.0)
ntokens = (~pad_mask).sum()
else:
loss = loss.squeeze(-1)
ntokens = target.numel()
if reduce:
loss = loss.sum()
return loss, ntokens
@dataclass
class ScstRewardCriterionConfig(FairseqDataclass):
scst_cider_cached_tokens: str = field(
default="coco-train-words.p",
metadata={"help": "path to cached cPickle file used to calculate CIDEr scores"},
)
ignore_prefix_size: int = field(
default=0,
metadata={"help": "Ignore first N tokens"},
)
sentence_avg: bool = II("optimization.sentence_avg")
constraint_range: Optional[str] = field(
default=None,
metadata={"help": "constraint range"}
)
@register_criterion(
"scst_reward_criterion", dataclass=ScstRewardCriterionConfig
)
class ScstRewardCriterion(FairseqCriterion):
CIDER_REWARD_WEIGHT = 1
def __init__(
self,
task,
scst_cider_cached_tokens,
sentence_avg,
ignore_prefix_size=0,
constraint_range=None
):
super().__init__(task)
self.scst_cider_scorer = CiderD(df=scst_cider_cached_tokens)
self.sentence_avg = sentence_avg
self.ignore_prefix_size = ignore_prefix_size
self.transtab = str.maketrans({key: None for key in string.punctuation})
self.constraint_start = None
self.constraint_end = None
if constraint_range is not None:
constraint_start, constraint_end = constraint_range.split(',')
self.constraint_start = int(constraint_start)
self.constraint_end = int(constraint_end)
def forward(self, model, sample, update_num=0, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
loss, score, ntokens, nsentences = self.compute_loss(model, sample, reduce=reduce)
sample_size = (
nsentences if self.sentence_avg else ntokens
)
logging_output = {
"loss": loss.data,
"score": score,
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
return loss, sample_size, logging_output
def _calculate_eval_scores(self, gen_res, gt_idx, gt_res):
'''
gen_res: generated captions, list of str
gt_idx: list of int, of the same length as gen_res
gt_res: ground truth captions, list of list of str.
gen_res[i] corresponds to gt_res[gt_idx[i]]
Each image can have multiple ground truth captions
'''
gen_res_size = len(gen_res)
res = OrderedDict()
for i in range(gen_res_size):
res[i] = [self._wrap_sentence(gen_res[i].strip().translate(self.transtab))]
gts = OrderedDict()
gt_res_ = [
[self._wrap_sentence(gt_res[i][j].strip().translate(self.transtab)) for j in range(len(gt_res[i]))]
for i in range(len(gt_res))
]
for i in range(gen_res_size):
gts[i] = gt_res_[gt_idx[i]]
res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
_, batch_cider_scores = self.scst_cider_scorer.compute_score(gts, res_)
scores = self.CIDER_REWARD_WEIGHT * batch_cider_scores
return scores
@classmethod
def _wrap_sentence(self, s):
# ensure the sentence ends with <eos> token
# in order to keep consisitent with cider_cached_tokens
r = s.strip()
if r.endswith('.'):
r = r[:-1]
r += ' <eos>'
return r
def get_generator_out(self, model, sample):
def decode(toks):
hypo = toks.int().cpu()
hypo_str = self.task.tgt_dict.string(hypo)
hypo_str = self.task.bpe.decode(hypo_str).strip()
return hypo, hypo_str
model.eval()
with torch.no_grad():
self.task.scst_generator.model.eval()
gen_out = self.task.scst_generator.generate([model], sample)
gen_target = []
gen_res = []
gt_res = []
for i in range(len(gen_out)):
for j in range(len(gen_out[i])):
hypo, hypo_str = decode(gen_out[i][j]["tokens"])
gen_target.append(hypo)
gen_res.append(hypo_str)
gt_res.append(
decode(utils.strip_pad(sample["target"][i], self.padding_idx))[1].split('&&')
)
return gen_target, gen_res, gt_res
def get_reward_and_scores(self, gen_res, gt_res, device):
batch_size = len(gt_res)
gen_res_size = len(gen_res)
seq_per_img = gen_res_size // batch_size
gt_idx = [i // seq_per_img for i in range(gen_res_size)]
scores = self._calculate_eval_scores(gen_res, gt_idx, gt_res)
sc_ = scores.reshape(batch_size, seq_per_img)
baseline = (sc_.sum(1, keepdims=True) - sc_) / (sc_.shape[1] - 1)
# sample - baseline
reward = scores.reshape(batch_size, seq_per_img)
reward = reward - baseline
reward = reward.reshape(gen_res_size)
reward = torch.as_tensor(reward, device=device, dtype=torch.float64)
return reward, scores
def get_net_output(self, model, sample, gen_target):
def merge(sample_list, eos=self.task.tgt_dict.eos(), move_eos_to_beginning=False):
return data_utils.collate_tokens(
sample_list,
pad_idx=self.padding_idx,
eos_idx=eos,
left_pad=False,
move_eos_to_beginning=move_eos_to_beginning,
)
batch_size = len(sample["target"])
gen_target_size = len(gen_target)
seq_per_img = gen_target_size // batch_size
model.train()
sample_src_tokens = torch.repeat_interleave(
sample['net_input']['src_tokens'], seq_per_img, dim=0
)
sample_src_lengths = torch.repeat_interleave(
sample['net_input']['src_lengths'], seq_per_img, dim=0
)
sample_patch_images = torch.repeat_interleave(
sample['net_input']['patch_images'], seq_per_img, dim=0
)
sample_patch_masks = torch.repeat_interleave(
sample['net_input']['patch_masks'], seq_per_img, dim=0
)
gen_prev_output_tokens = torch.as_tensor(
merge(gen_target, eos=self.task.tgt_dict.bos(), move_eos_to_beginning=True),
device=sample["target"].device, dtype=torch.int64
)
gen_target_tokens = torch.as_tensor(
merge(gen_target), device=sample["target"].device, dtype=torch.int64
)
net_output = model(
src_tokens=sample_src_tokens, src_lengths=sample_src_lengths,
patch_images=sample_patch_images, patch_masks=sample_patch_masks,
prev_output_tokens=gen_prev_output_tokens
)
return net_output, gen_target_tokens
def get_lprobs_and_target(self, model, net_output, gen_target):
if self.constraint_start is not None and self.constraint_end is not None:
net_output[0][:, :, 4:self.constraint_start] = -math.inf
net_output[0][:, :, self.constraint_end:] = -math.inf
lprobs = model.get_normalized_probs(net_output, log_probs=True)
if self.ignore_prefix_size > 0:
if getattr(lprobs, "batch_first", False):
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
gen_target = gen_target[:, self.ignore_prefix_size :].contiguous()
else:
lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
gen_target = gen_target[self.ignore_prefix_size :, :].contiguous()
return lprobs, gen_target
def compute_loss(self, model, sample, reduce=True):
gen_target, gen_res, gt_res = self.get_generator_out(model, sample)
reward, scores = self.get_reward_and_scores(gen_res, gt_res, device=sample["target"].device)
net_output, gen_target_tokens = self.get_net_output(model, sample, gen_target)
gen_lprobs, gen_target_tokens = self.get_lprobs_and_target(model, net_output, gen_target_tokens)
loss, ntokens = scst_loss(gen_lprobs, gen_target_tokens, reward, ignore_index=self.padding_idx, reduce=reduce)
nsentences = gen_target_tokens.size(0)
return loss, scores.sum(), ntokens, nsentences
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
score_sum = sum(log.get("score", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size, sample_size, round=3
)
metrics.log_scalar(
"score", score_sum / nsentences, nsentences, round=3
)
metrics.log_scalar(
"ntokens", ntokens, 1, round=3
)
metrics.log_scalar(
"nsentences", nsentences, 1, round=3
)
metrics.log_scalar(
"sample_size", sample_size, 1, round=3
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
This diff is collapsed.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
from io import BytesIO
import logging
import warnings
import functools
import numpy as np
import torch
import base64
from torchvision import transforms
from timm.data import create_transform
from utils.vision_helper import RandomAugment
from PIL import Image, ImageFile
from data import data_utils
from data.ofa_dataset import OFADataset
ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
def collate(samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}
def merge(key):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
eos_idx=eos_idx,
)
id = np.array([s["id"] for s in samples])
src_tokens = merge("source")
src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
conf = None
if samples[0].get("conf", None) is not None:
conf = torch.cat([s['conf'] for s in samples], dim=0)
ref_dict = None
if samples[0].get("ref_dict", None) is not None:
ref_dict = np.array([s['ref_dict'] for s in samples])
constraint_masks = None
if samples[0].get("constraint_mask", None) is not None:
constraint_masks = merge("constraint_mask")
prev_output_tokens = None
target = None
if samples[0].get("target", None) is not None:
target = merge("target")
tgt_lengths = torch.LongTensor(
[s["target"].ne(pad_idx).long().sum() for s in samples]
)
ntokens = tgt_lengths.sum().item()
if samples[0].get("prev_output_tokens", None) is not None:
prev_output_tokens = merge("prev_output_tokens")
else:
ntokens = src_lengths.sum().item()
batch = {
"id": id,
"nsentences": len(samples),
"ntokens": ntokens,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"patch_images": patch_images,
"patch_masks": patch_masks,
"prev_output_tokens": prev_output_tokens
},
"conf": conf,
"ref_dict": ref_dict,
"constraint_masks": constraint_masks,
"target": target,
}
return batch
class ImageClassifyDataset(OFADataset):
def __init__(
self,
split,
dataset,
bpe,
src_dict,
tgt_dict=None,
max_src_length=128,
max_tgt_length=30,
patch_image_size=224,
constraint_trie=None,
imagenet_default_mean_and_std=False
):
super().__init__(split, dataset, bpe, src_dict, tgt_dict)
self.max_src_length = max_src_length
self.max_tgt_length = max_tgt_length
self.patch_image_size = patch_image_size
self.constraint_trie = constraint_trie
if imagenet_default_mean_and_std:
mean = IMAGENET_DEFAULT_MEAN
std = IMAGENET_DEFAULT_STD
else:
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
if self.split != 'train':
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert("RGB"),
transforms.Resize([patch_image_size, patch_image_size], interpolation=Image.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
logger.info("val split, do not use random augmentation.")
else:
self.patch_resize_transform = create_transform(
input_size=patch_image_size,
is_training=True,
color_jitter=0.4,
auto_augment='rand-m9-mstd0.5-inc1',
interpolation='bicubic',
re_prob=0.25,
re_mode='pixel',
re_count=1,
mean=mean,
std=std,
)
self.patch_resize_transform = transforms.Compose(functools.reduce(lambda x, y:x + y, [
[lambda image: image.convert("RGB"),],
self.patch_resize_transform.transforms[:2],
[self.patch_resize_transform.transforms[2]],
[RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), ],
self.patch_resize_transform.transforms[3:],
]))
logger.info("train split, use random augmentation.")
def __getitem__(self, index):
image, label_name = self.dataset[index]
image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
patch_image = self.patch_resize_transform(image)
patch_mask = torch.tensor([True])
src_item = self.encode_text(' what does the image describe?')
tgt_item = self.encode_text(" {}".format(label_name))
ref_dict = {label_name: 1.0}
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
target_item = torch.cat([tgt_item, self.eos_item])
prev_output_item = torch.cat([self.bos_item, tgt_item])
example = {
"id": index,
"source": src_item,
"patch_image": patch_image,
"patch_mask": patch_mask,
"target": target_item,
"prev_output_tokens": prev_output_item,
"ref_dict": ref_dict,
}
if self.constraint_trie is not None:
constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
for i in range(len(prev_output_item)):
constraint_prefix_token = prev_output_item[:i+1].tolist()
constraint_nodes = self.constraint_trie.get_next_layer(constraint_prefix_token)
constraint_mask[i][constraint_nodes] = True
example["constraint_mask"] = constraint_mask
return example
def collater(self, samples, pad_to_length=None):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch containing the data of the task
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment