Commit 72f5785f authored by huaerkl

v1.0
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
TODO (huxu): a general fairseq criterion for all your pre-defined losses.
"""

from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.logging import metrics


@register_criterion("mmloss")
class MMCriterion(FairseqCriterion):
    def __init__(self, task):
        super().__init__(task)
        # TODO (huxu): wrap forward call of loss_fn and eval_fn into task.
        self.mmtask = task.mmtask

    def forward(self, model, sample):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        outputs = self.mmtask(model, sample)
        loss, loss_scalar, max_len, batch_size, sample_size = (
            outputs["loss"],
            outputs["loss_scalar"],
            outputs["max_len"],
            outputs["batch_size"],
            outputs["sample_size"],
        )
        logging_output = {
            "loss": loss_scalar,
            "ntokens": max_len * batch_size,  # dummy report.
            "nsentences": batch_size,  # dummy report.
            "sample_size": sample_size,
        }
        return loss, 1, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        Since we use NCE, the actual batch size is 1 per GPU, so we take
        the mean over workers.
        """
        loss_sum = sum(log.get("loss", 0.0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        metrics.log_scalar("loss", loss_sum / sample_size, round=3)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True improves distributed training speed.
        """
        return True
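
A quick illustration (added here, not part of the original commit): `MMCriterion` delegates all loss computation to `task.mmtask`, so the only contract is the dictionary returned by that call. The stand-in function below is hypothetical and just shows the keys that `forward` unpacks.

import torch

def _fake_mmtask(model, sample):
    # hypothetical stand-in for task.mmtask(model, sample)
    loss = model(sample).sum()
    return {
        "loss": loss,                # tensor used for the backward pass
        "loss_scalar": loss.item(),  # python float for logging
        "max_len": 32,               # used only for the dummy ntokens report
        "batch_size": 4,
        "sample_size": 1,            # NCE: effective batch size of 1 per GPU
    }

_outputs = _fake_mmtask(torch.nn.Linear(8, 8), torch.randn(4, 8))
assert set(_outputs) == {
    "loss", "loss_scalar", "max_len", "batch_size", "sample_size"
}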
# Copyright (c) Facebook, Inc. All Rights Reserved

import torch
from torch import nn


class Loss(object):
    def __call__(self, *args, **kwargs):
        raise NotImplementedError


# Dummy Loss for testing.
class DummyLoss(Loss):
    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        return self.loss(logits, targets)


class DummyK400Loss(Loss):
    """dummy k400 loss for MViT."""

    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        # ignore the given targets; sample random Kinetics-400 labels instead.
        return self.loss(
            logits,
            torch.randint(0, 400, (logits.size(0),), device=logits.device),
        )


class CrossEntropy(Loss):
    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        # flatten (batch, seq, vocab) logits for token-level cross-entropy.
        return self.loss(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))


class ArgmaxCrossEntropy(Loss):
    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        return self.loss(logits, targets.argmax(dim=1))


class BCE(Loss):
    def __init__(self):
        self.loss = nn.BCEWithLogitsLoss()

    def __call__(self, logits, targets, **kwargs):
        targets = targets.squeeze(0)
        return self.loss(logits, targets)


class NLGLoss(Loss):
    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, text_label, **kwargs):
        # -100 marks positions excluded from the generation loss.
        targets = text_label[text_label != -100]
        return self.loss(logits, targets)


class MSE(Loss):
    def __init__(self):
        self.loss = nn.MSELoss()

    def __call__(self, logits, targets, **kwargs):
        return self.loss(logits, targets)


class L1(Loss):
    def __init__(self):
        self.loss = nn.L1Loss()

    def __call__(self, logits, targets, **kwargs):
        return self.loss(logits, targets)


class SmoothL1(Loss):
    def __init__(self):
        self.loss = nn.SmoothL1Loss()

    def __call__(self, logits, targets, **kwargs):
        return self.loss(logits, targets)
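
A quick sanity check (illustrative only, not from the commit) exercising two of the wrappers above; the shapes follow the reshape logic in `CrossEntropy`:

import torch

ce = CrossEntropy()
logits = torch.randn(2, 5, 10)           # (batch, seq, vocab)
targets = torch.randint(0, 10, (2, 5))   # (batch, seq)
print(ce(logits, targets))               # scalar token-level cross-entropy

smooth = SmoothL1()
print(smooth(torch.randn(4, 8), torch.randn(4, 8)))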
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .mmfusion import *
from .transformermodel import *
from .mmfusionnlg import *

try:
    from .fairseqmmmodel import *
except ImportError:
    pass

try:
    from .expmmfusion import *
except ImportError:
    pass
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.models import (
    BaseFairseqModel,
    register_model,
    register_model_architecture,
)


@register_model("mmmodel")
class FairseqMMModel(BaseFairseqModel):
    """a fairseq wrapper of model built by `task`."""

    @classmethod
    def build_model(cls, args, task):
        return FairseqMMModel(task.mmtask.model)

    def __init__(self, mmmodel):
        super().__init__()
        self.mmmodel = mmmodel

    def forward(self, *args, **kwargs):
        return self.mmmodel(*args, **kwargs)

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        keys_to_delete = []
        # drop checkpoint parameters the current model no longer defines.
        for key in state_dict:
            if key not in self.state_dict():
                keys_to_delete.append(key)
        for key in keys_to_delete:
            print("[INFO]", key, "not used anymore.")
            del state_dict[key]
        # copy any newly defined parameters.
        for key in self.state_dict():
            if key not in state_dict:
                print("[INFO] adding", key)
                state_dict[key] = self.state_dict()[key]


# a dummy architecture; the model itself is configured by the task.
@register_model_architecture("mmmodel", "mmarch")
def mmarch(args):
    pass
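
Because `mmarch` is a no-op, selecting this architecture only routes fairseq to `FairseqMMModel.build_model`, which wraps whatever model the task already built. A minimal sketch of that flow (the stand-in task objects below are hypothetical, added for illustration):

import torch

class _FakeMMTask:
    model = torch.nn.Linear(8, 8)  # stands in for the model the task builds

class _FakeTask:
    mmtask = _FakeMMTask()

wrapped = FairseqMMModel.build_model(args=None, task=_FakeTask())
print(wrapped(torch.randn(2, 8)).shape)  # forward() delegates to the wrapped model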