Commit f9b1a89a authored by HHL

v

parent 60e27226
from typing import Any, Dict, Union

import torch
from torch.utils.data.dataloader import DataLoader
from transformers import Trainer


class FunsdTrainer(Trainer):
    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        for k, v in inputs.items():
            if hasattr(v, "to") and hasattr(v, "device"):
                inputs[k] = v.to(self.args.device)

        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs
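# Hedged usage sketch (not part of the original commit): FunsdTrainer only overrides
# `_prepare_inputs`, so it is driven like a stock `transformers.Trainer`; the names
# `model`, `train_dataset`, and `collator` below are placeholders for the real objects.
#
#   from transformers import TrainingArguments
#
#   args = TrainingArguments(output_dir="out", per_device_train_batch_size=2)
#   trainer = FunsdTrainer(model=model, args=args, train_dataset=train_dataset, data_collator=collator)
#   trainer.train()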
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from transformers.utils import logging
import torch

logger = logging.get_logger(__name__)


class NanDetector:
    """
    Detects the first NaN or Inf in the forward and/or backward pass and logs it, together with the module name
    """

    def __init__(self, model, forward=True, backward=True):
        self.bhooks = []
        self.fhooks = []
        self.forward = forward
        self.backward = backward
        self.named_parameters = list(model.named_parameters())
        self.reset()

        for name, mod in model.named_modules():
            mod.__module_name = name
            self.add_hooks(mod)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Dump out all model gnorms to enable better debugging
        norm = {}
        gradients = {}
        for name, param in self.named_parameters:
            if param.grad is not None:
                grad_norm = torch.norm(param.grad.data, p=2, dtype=torch.float32)
                norm[name] = grad_norm.item()
                if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any():
                    gradients[name] = param.grad.data
        if len(gradients) > 0:
            logger.info("Detected nan/inf grad norm, dumping norms...")
            logger.info(f"norms: {norm}")
            logger.info(f"gradients: {gradients}")

        self.close()

    def add_hooks(self, module):
        if self.forward:
            self.fhooks.append(module.register_forward_hook(self.fhook_fn))
        if self.backward:
            self.bhooks.append(module.register_backward_hook(self.bhook_fn))

    def reset(self):
        self.has_printed_f = False
        self.has_printed_b = False

    def _detect(self, tensor, name, backward):
        err = None
        if (
            torch.is_floating_point(tensor)
            # single value tensors (like the loss) will not provide much info
            and tensor.numel() >= 2
        ):
            with torch.no_grad():
                if torch.isnan(tensor).any():
                    err = "NaN"
                elif torch.isinf(tensor).any():
                    err = "Inf"
        if err is not None:
            err = f"{err} detected in output of {name}, shape: {tensor.shape}, {'backward' if backward else 'forward'}"
        return err

    def _apply(self, module, inp, x, backward):
        if torch.is_tensor(x):
            if isinstance(inp, tuple) and len(inp) > 0:
                inp = inp[0]
            err = self._detect(x, module.__module_name, backward)
            if err is not None:
                if torch.is_tensor(inp) and not backward:
                    err += (
                        f" input max: {inp.max().item()}, input min: {inp.min().item()}"
                    )

                has_printed_attr = "has_printed_b" if backward else "has_printed_f"
                logger.warning(err)
                setattr(self, has_printed_attr, True)
        elif isinstance(x, dict):
            for v in x.values():
                self._apply(module, inp, v, backward)
        elif isinstance(x, list) or isinstance(x, tuple):
            for v in x:
                self._apply(module, inp, v, backward)

    def fhook_fn(self, module, inp, output):
        if not self.has_printed_f:
            self._apply(module, inp, output, backward=False)

    def bhook_fn(self, module, inp, output):
        if not self.has_printed_b:
            self._apply(module, inp, output, backward=True)

    def close(self):
        for hook in self.fhooks + self.bhooks:
            hook.remove()
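# Hedged usage sketch (not part of the original file): NanDetector is a context manager, so a
# debugging run can wrap a training step with it to log the first module whose forward or
# backward output contains NaN/Inf; `model` and `batch` are placeholders for the real objects.
#
#   with NanDetector(model):
#       loss = model(**batch).loss
#       loss.backward()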
import torch
from torch import nn
from libs.configs.default import counter
from transformers import Trainer
from typing import Any, Dict, Union
from torch.utils.data.distributed import DistributedSampler
from libs.utils.comm import distributed, get_rank, get_world_size
from transformers.trainer import *
from .nan_detector import NanDetector


class PreTrainer(Trainer):
    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        for k, v in inputs.items():
            if hasattr(v, "to") and hasattr(v, "device"):
                inputs[k] = v.to(self.args.device)

        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs

    def get_train_dataloader(self):
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        if distributed():
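            # DistributedSampler's positional arguments here are (dataset, num_replicas, rank, shuffle),
            # so the trailing True enables shuffling across ranks.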
            sampler = DistributedSampler(self.train_dataset, get_world_size(), get_rank(), True)
            dataloader = torch.utils.data.DataLoader(
                self.train_dataset,
                sampler=sampler,
                num_workers=self.args.dataloader_num_workers,
                batch_size=self.args.train_batch_size,
                collate_fn=self.data_collator,
                drop_last=self.args.dataloader_drop_last,
                pin_memory=self.args.dataloader_pin_memory,
            )
        else:
            dataloader = torch.utils.data.DataLoader(
                self.train_dataset,
                num_workers=self.args.dataloader_num_workers,
                batch_size=self.args.train_batch_size,
                collate_fn=self.data_collator,
                shuffle=True,
                drop_last=self.args.dataloader_drop_last,
                pin_memory=self.args.dataloader_pin_memory,
            )
        return dataloader

    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
        if self.control.should_log:
            logs: Dict[str, float] = {}
            tr_loss_scalar = tr_loss.item()
            # reset tr_loss to zero
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            logs["learning_rate"] = round(self._get_learning_rate(), 10)
            logs["cuda_max_memory"] = int(torch.cuda.max_memory_allocated() / 1024 / 1024)
            logs = dict(logs, **counter.dict_mean(sync=False))

            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            metrics = self.evaluate()
            self._report_to_hp_search(trial, epoch, metrics)

        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
import collections
import time
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset

from transformers.trainer_utils import EvalPrediction, PredictionOutput, speed_metrics
from transformers.utils import logging

from .funsd_trainer import FunsdTrainer

if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_native_amp_available = True
    from torch.cuda.amp import autocast

logger = logging.get_logger(__name__)


class XfunSerTrainer(FunsdTrainer):
    pass


class XfunReTrainer(FunsdTrainer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.label_names.append("relations")

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            if self.use_amp:
                with autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)

        labels = tuple(inputs.get(name) for name in self.label_names)
        return outputs, labels
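    # Note: unlike the base Trainer.prediction_step, which returns (loss, logits, labels), this
    # override returns the full model output object plus the raw label tuple, so that
    # prediction_loop below can read .pred_relations and .entities directly.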
    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.
        """
        if not isinstance(dataloader.dataset, collections.abc.Sized):
            raise ValueError("dataset must implement __len__")
        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        if self.args.deepspeed and not self.args.do_train:
            # no harm, but flagging to the user that deepspeed config is ignored for eval
            # flagging only for when --do_train wasn't passed as only then it's redundant
            logger.info("Detected the deepspeed argument but it will not be used for evaluation")

        model = self._wrap_model(self.model, training=False)

        # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
        # ``train`` is running, half it first and then put on device
        if not self.is_in_train and self.args.fp16_full_eval:
            model = model.half().to(self.args.device)

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info("***** Running %s *****", description)
        logger.info(" Num examples = %d", num_examples)
        logger.info(" Batch size = %d", batch_size)

        model.eval()

        self.callback_handler.eval_dataloader = dataloader

        re_labels = None
        pred_relations = None
        entities = None
        for step, inputs in enumerate(dataloader):
            outputs, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            re_labels = labels[1] if re_labels is None else re_labels + labels[1]
            pred_relations = (
                outputs.pred_relations if pred_relations is None else pred_relations + outputs.pred_relations
            )
            entities = outputs.entities if entities is None else entities + outputs.entities

            self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

        gt_relations = []
        for b in range(len(re_labels)):
            rel_sent = []
            for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]):
                rel = {}
                rel["head_id"] = head
                rel["head"] = (entities[b]["start"][rel["head_id"]], entities[b]["end"][rel["head_id"]])
                rel["head_type"] = entities[b]["label"][rel["head_id"]]
                rel["tail_id"] = tail
                rel["tail"] = (entities[b]["start"][rel["tail_id"]], entities[b]["end"][rel["tail_id"]])
                rel["tail_type"] = entities[b]["label"][rel["tail_id"]]
                rel["type"] = 1
                rel_sent.append(rel)
            gt_relations.append(rel_sent)

        re_metrics = self.compute_metrics(EvalPrediction(predictions=pred_relations, label_ids=gt_relations))
        re_metrics = {
            "precision": re_metrics["ALL"]["p"],
            "recall": re_metrics["ALL"]["r"],
            "f1": re_metrics["ALL"]["f1"],
        }
        re_metrics[f"{metric_key_prefix}_loss"] = outputs.loss.mean().item()

        metrics = {}
        # Prefix all keys with metric_key_prefix + '_'
        for key in list(re_metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = re_metrics.pop(key)
            else:
                metrics[f"{key}"] = re_metrics.pop(key)

        return metrics
    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run evaluation and return the obtained metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init :obj:`compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
                columns not accepted by the ``model.forward()`` method are automatically removed. It must implement
                the :obj:`__len__` method.
            ignore_keys (:obj:`List[str]`, `optional`):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
                An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default).

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
        """
        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
            raise ValueError("eval_dataset must implement __len__")
        self.args.local_rank = -1
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        self.args.local_rank = torch.distributed.get_rank()

        start_time = time.time()

        metrics = self.prediction_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        n_samples = len(eval_dataset if eval_dataset is not None else self.eval_dataset)
        metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples))
        self.log(metrics)

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
        return metrics
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import torch
from transformers.file_utils import ModelOutput


@dataclass
class ReOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    entities: Optional[Dict] = None
    relations: Optional[Dict] = None
    pred_relations: Optional[Dict] = None
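# Hedged sketch (not part of the original file): a relation-extraction head would typically wrap
# its results in this ModelOutput subclass, which is what XfunReTrainer.prediction_loop reads, e.g.
#
#   return ReOutput(loss=loss, logits=logits, entities=entities, relations=relations,
#                   pred_relations=pred_relations)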
from . import default
import importlib


class CFG:
    def __init__(self):
        self.__dict__['cfg'] = None

    def __getattr__(self, name):
        return getattr(self.__dict__['cfg'], name)

    def __setattr__(self, name, val):
        setattr(self.__dict__['cfg'], name, val)


cfg = CFG()
cfg.__dict__['cfg'] = default


def setup_config(cfg_name):
    global cfg
    module_name = 'libs.configs.' + cfg_name
    cfg_module = importlib.import_module(module_name)
    cfg.__dict__['cfg'] = cfg_module
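# Hedged usage sketch (not part of this commit): `cfg` is a proxy whose attribute reads and writes
# are forwarded to the currently selected config module (libs.configs.default until setup_config
# swaps it in). The config name and option below are hypothetical examples.
#
#   from libs.configs import cfg, setup_config
#
#   setup_config("pretrain")        # switches the proxy to libs.configs.pretrain
#   print(cfg.train_batch_size)     # attribute lookup is forwarded to that module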