# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from collections import defaultdict
import copy
import importlib.util
import math
import os
import sys
from typing import Callable, List
import warnings

import torch
import torch.nn.functional as F

from fairseq.modules import gelu, gelu_accurate


def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Deprecated wrapper around ``checkpoint_utils.load_model_ensemble``.

    Emits a deprecation warning and forwards ``filenames``, ``task`` and
    ``model_arg_overrides`` (as ``arg_overrides``) to the replacement.
    """
    # imported lazily, inside the function, as in the original code
    from fairseq import checkpoint_utils
    deprecation_warning(
        'utils.load_ensemble_for_inference is deprecated. '
        'Please use checkpoint_utils.load_model_ensemble instead.'
    )
    return checkpoint_utils.load_model_ensemble(
        filenames, arg_overrides=model_arg_overrides, task=task,
    )


def move_to_cuda(sample):
    """Return a copy of ``sample`` with every contained tensor moved to CUDA.

    Dicts and lists are traversed recursively; any other non-tensor value is
    returned unchanged. A ``sample`` of length zero yields an empty dict.
    """
    # NOTE: keep the explicit len() check -- truth-testing a multi-element
    # tensor raises, so ``not sample`` would not be equivalent.
    if len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda()
        elif isinstance(maybe_tensor, dict):
            return {
                key: _move_to_cuda(value)
                for key, value in maybe_tensor.items()
            }
        elif isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        else:
            return maybe_tensor

    return _move_to_cuda(sample)


# Per-class counter used to hand out instance IDs (see
# _get_full_incremental_state_key).
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(int)


def _get_full_incremental_state_key(module_instance, key):
    """Build the ``'<ClassName>.<instance id>.<key>'`` incremental-state key.

    On first use the instance is tagged with a ``_fairseq_instance_id``
    attribute taken from a per-class counter.
    """
    module_name = module_instance.__class__.__name__

    # assign a unique ID to each module instance, so that incremental state is
    # not shared across module instances
    if not hasattr(module_instance, '_fairseq_instance_id'):
        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
        module_instance._fairseq_instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]

    return '{}.{}.{}'.format(module_name, module_instance._fairseq_instance_id, key)


def get_incremental_state(module, incremental_state, key):
    """Helper for getting incremental state for an nn.Module.

    Returns ``None`` when ``incremental_state`` is ``None`` or holds no entry
    for this module/key pair.
    """
    # The key is computed before the None check on purpose: it assigns the
    # module its instance ID even when there is no state to read.
    full_key = _get_full_incremental_state_key(module, key)
    if incremental_state is None or full_key not in incremental_state:
        return None
    return incremental_state[full_key]


def set_incremental_state(module, incremental_state, key, value):
    """Helper for setting incremental state for an nn.Module.

    Does nothing when ``incremental_state`` is ``None``.
    """
    if incremental_state is not None:
        full_key = _get_full_incremental_state_key(module, key)
        incremental_state[full_key] = value


def load_align_dict(replace_unk):
    """Build the alignment dictionary used for unknown-word replacement.

    Args:
        replace_unk: ``None`` to disable replacement, a ``str`` path to a
            whitespace-separated two-column file (first column -> second
            column), or any other value to enable replacement without a
            dictionary.

    Returns:
        ``None`` if ``replace_unk`` is ``None``, the mapping read from the
        file if it is a path, otherwise an empty dict.
    """
    if replace_unk is None:
        align_dict = None
    elif isinstance(replace_unk, str):
        # Load alignment dictionary for unknown word replacement if it was passed as an argument.
        align_dict = {}
        with open(replace_unk, 'r') as f:
            for line in f:
                cols = line.split()
                if not cols:
                    # tolerate blank lines (e.g. a trailing newline at EOF)
                    continue
                align_dict[cols[0]] = cols[1]
    else:
        # No alignment dictionary provided but we still want to perform unknown word replacement by copying the
        # original source word.
        align_dict = {}
    return align_dict


def print_embed_overlap(embed_dict, vocab_dict):
    """Print how many of ``vocab_dict.symbols`` are keys of ``embed_dict``."""
    embed_keys = set(embed_dict.keys())
    vocab_keys = set(vocab_dict.symbols)
    overlap = len(embed_keys & vocab_keys)
    print("| Found {}/{} types in embedding file.".format(overlap, len(vocab_dict)))


def parse_embedding(embed_path):
    """Parse embedding text file into a dictionary of word and embedding tensors.

    The first line is always skipped as a header (it is expected to hold the
    vocabulary size and dimension). The following lines should contain word
    and embedding separated by spaces. An empty file yields an empty dict.

    Example:
        2 5
        the -0.0230 -0.0264 0.0287 0.0171 0.1403
        at -0.0395 -0.1286 0.0275 0.0254 -0.0932
    """
    embed_dict = {}
    with open(embed_path) as f_embed:
        # skip header; the default keeps an empty file from raising StopIteration
        next(f_embed, None)
        for line in f_embed:
            pieces = line.rstrip().split(" ")
            embed_dict[pieces[0]] = torch.Tensor([float(weight) for weight in pieces[1:]])
    return embed_dict


def load_embedding(embed_dict, vocab, embedding):
    """Copy vectors from ``embed_dict`` into ``embedding.weight`` rows.

    Row ``idx`` is overwritten whenever ``vocab[idx]`` is a key of
    ``embed_dict``; all other rows are left as they are. Returns ``embedding``.
    """
    # Index explicitly rather than iterating ``vocab``: only ``len()`` and
    # integer indexing are relied upon here.
    for idx in range(len(vocab)):
        token = vocab[idx]
        if token in embed_dict:
            embedding.weight.data[idx] = embed_dict[token]
    return embedding


def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
    """Replace each ``unk`` token in ``hypo_str`` using its aligned source token.

    For hypothesis position ``i`` holding ``unk``, the source token at index
    ``alignment[i]`` is looked up in ``align_dict``; if absent, the source
    token itself is copied. Returns the tokens re-joined with single spaces.
    """
    from fairseq import tokenizer
    # Tokens are strings here
    hypo_tokens = tokenizer.tokenize_line(hypo_str)
    # TODO: Very rare cases where the replacement is '' should be handled gracefully
    src_tokens = tokenizer.tokenize_line(src_str) + ['']
    for i, ht in enumerate(hypo_tokens):
        if ht == unk:
            src_token = src_tokens[alignment[i]]
            # Either take the corresponding value in the aligned dictionary or just copy the original value.
            hypo_tokens[i] = align_dict.get(src_token, src_token)
    return ' '.join(hypo_tokens)


def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict, remove_bpe):
    """Turn hypothesis tokens into a string, applying unk replacement / BPE removal.

    Returns:
        A ``(hypo_tokens, hypo_str, alignment)`` tuple. ``hypo_tokens`` is
        re-encoded from ``hypo_str`` when ``align_dict`` or ``remove_bpe`` is
        not ``None``; ``alignment`` is passed through unchanged.
    """
    hypo_str = tgt_dict.string(hypo_tokens, remove_bpe)
    if align_dict is not None:
        hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string())
    if align_dict is not None or remove_bpe is not None:
        # Convert back to tokens for evaluating with unk replacement or without BPE
        # Note that the dictionary can be modified inside the method.
        hypo_tokens = tgt_dict.encode_line(hypo_str, add_if_not_exist=True)
    return hypo_tokens, hypo_str, alignment


def make_positions(tensor, padding_idx, onnx_trace=False):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored.
    Positions are counted along dimension 1. ``onnx_trace`` is accepted but
    not used by this implementation.
    """
    mask = tensor.ne(padding_idx).long()
    # cumsum counts non-pad symbols; multiplying by the mask zeroes pad slots
    # so they end up equal to padding_idx after the offset.
    return torch.cumsum(mask, dim=1) * mask + padding_idx


def strip_pad(tensor, pad):
    """Return the elements of ``tensor`` that are not equal to ``pad``."""
    return tensor[tensor.ne(pad)]


def buffered_arange(max):
    """Return ``[0, max)`` as a slice of a cached, grow-only LongTensor.

    The buffer is stored on the function object and is only regenerated when
    a larger range than the cached one is requested.
    """
    if not hasattr(buffered_arange, 'buf'):
        buffered_arange.buf = torch.LongTensor()
    if max > buffered_arange.buf.numel():
        torch.arange(max, out=buffered_arange.buf)
    return buffered_arange.buf[:max]


def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    """Move the padding of each row to the other side.

    Exactly one of ``right_to_left`` / ``left_to_right`` must be true. Each
    row is rotated along dimension 1 by its number of padding symbols. The
    input is returned as-is when it contains no padding or when no row has
    padding on the side being converted from.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    positions = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)


def item(tensor):
    """Extract a plain value from ``tensor``.

    Uses ``.item()`` when available, otherwise the first element when the
    object is indexable, otherwise returns the object unchanged.
    """
    if hasattr(tensor, 'item'):
        return tensor.item()
    if hasattr(tensor, '__getitem__'):
        return tensor[0]
    return tensor


def clip_grad_norm_(tensor, max_norm):
    """Scale ``tensor`` in place so its norm does not exceed ``max_norm``.

    Clipping only happens when ``max_norm`` is positive and smaller than the
    current norm. Returns the norm measured before any scaling.
    """
    grad_norm = item(torch.norm(tensor))
    if grad_norm > max_norm > 0:
        # the small epsilon guards the division
        clip_coef = max_norm / (grad_norm + 1e-6)
        tensor.mul_(clip_coef)
    return grad_norm


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float('-inf')).type_as(t)


def resolve_max_positions(*args):
    """Resolve max position constraints from multiple sources.

    ``None`` arguments are ignored. The first non-``None`` argument seeds the
    result; each later one tightens it: numbers via ``min``, dicts key-wise
    (keys missing on one side are copied over), and any other value is zipped
    element-wise with the current result, taking the minimum while ignoring
    ``None`` entries.
    """

    def map_value_update(d1, d2):
        merged = copy.deepcopy(d1)
        for key in d2:
            if key not in merged:
                merged[key] = d2[key]
            else:
                merged[key] = min(d1[key], d2[key])
        return merged

    def nullsafe_min(values):
        # smallest non-None entry, or None when there is none
        return min((v for v in values if v is not None), default=None)

    max_positions = None
    for arg in args:
        if max_positions is None:
            max_positions = arg
        elif arg is not None:
            if isinstance(arg, (float, int)):
                max_positions = min(max_positions, arg)
            elif isinstance(arg, dict):
                max_positions = map_value_update(max_positions, arg)
            else:
                max_positions = tuple(
                    map(nullsafe_min, zip(max_positions, arg))
                )
    return max_positions


def import_user_module(args):
    """Import the package located at ``args.user_dir``, if that is set.

    The package's parent directory is put at the front of ``sys.path`` only
    for the duration of the import. Nothing happens when ``user_dir`` is
    missing/``None`` or a module of that name is already in ``sys.modules``.
    """
    module_path = getattr(args, 'user_dir', None)
    if module_path is not None:
        module_path = os.path.abspath(module_path)
        module_parent, module_name = os.path.split(module_path)

        if module_name not in sys.modules:
            sys.path.insert(0, module_parent)
            try:
                importlib.import_module(module_name)
            finally:
                # restore sys.path even when the import raises
                sys.path.pop(0)


def softmax(x, dim, onnx_trace=False):
    """Softmax over ``dim`` computed in float32.

    With ``onnx_trace`` the input is cast via ``x.float()`` first; otherwise
    ``dtype=torch.float32`` is passed to ``F.softmax``.
    """
    if onnx_trace:
        return F.softmax(x.float(), dim=dim)
    else:
        return F.softmax(x, dim=dim, dtype=torch.float32)


def log_softmax(x, dim, onnx_trace=False):
    """Log-softmax over ``dim`` computed in float32.

    With ``onnx_trace`` the input is cast via ``x.float()`` first; otherwise
    ``dtype=torch.float32`` is passed to ``F.log_softmax``.
    """
    if onnx_trace:
        return F.log_softmax(x.float(), dim=dim)
    else:
        return F.log_softmax(x, dim=dim, dtype=torch.float32)


def get_perplexity(loss):
    """Return ``2 ** loss`` formatted with two decimals.

    NOTE: the normal result is a ``str``, but on ``OverflowError`` the float
    ``inf`` is returned instead; this mixed return type is kept as-is for
    existing callers.
    """
    try:
        return '{:.2f}'.format(math.pow(2, loss))
    except OverflowError:
        return float('inf')


def deprecation_warning(message, stacklevel=3):
    """Emit ``message`` as a warning with the given ``stacklevel``."""
    # don't use DeprecationWarning, since it's ignored by default
    warnings.warn(message, stacklevel=stacklevel)


def get_activation_fn(activation: str) -> Callable:
    """ Returns the activation function corresponding to `activation`

    Raises:
        RuntimeError: if ``activation`` is not one of the supported names.
    """
    if activation == 'relu':
        return F.relu
    elif activation == 'gelu':
        return gelu
    elif activation == 'gelu_fast':
        deprecation_warning('--activation-fn=gelu_fast has been renamed to gelu_accurate')
        return gelu_accurate
    elif activation == 'gelu_accurate':
        return gelu_accurate
    elif activation == 'tanh':
        # torch.tanh instead of the deprecated F.tanh, which warns on every call
        return torch.tanh
    else:
        raise RuntimeError(f"--activation-fn {activation} not supported")


def get_available_activation_fns() -> List:
    """Return the activation names accepted by ``get_activation_fn``."""
    return [
        'relu',
        'gelu',
        'gelu_fast',  # deprecated
        'gelu_accurate',
        'tanh',
    ]