# Copyright (c) Alibaba, Inc. and its affiliates. # Part of the implementation is borrowed from kmeng01/rome. """ Contains utilities for extracting token representations and indices from string templates. Used in computing the left and right vectors for ROME. """ from typing import Any, Callable, List, Tuple, Union import torch from modelscope import AutoTokenizer from .nethook import Trace def get_reprs_at_word_tokens( model: torch.nn.Module, tokenizer: Any, context_templates: List[str], words: List[str], layer: int, module_template: str, subtoken: str, track: str = 'in', batch_first: bool = True, ) -> torch.Tensor: """ Retrieves the last token representation of `word` in `context_template` when `word` is substituted into `context_template`. See `get_last_word_idx_in_template` for more details. """ idxs = get_words_idxs_in_templates(tokenizer, context_templates, words, subtoken) return get_reprs_at_idxs( model, tokenizer, [context_templates[i].format(words[i]) for i in range(len(words))], idxs, layer, module_template, track, batch_first, ) def get_words_idxs_in_templates(tokenizer: AutoTokenizer, context_templates: List[str], words: List[str], subtoken: str) -> List: """ Given list of template strings, each with *one* format specifier (e.g. "{} plays basketball"), and words to be substituted into the template, computes the post-tokenization index of their last tokens. """ assert all(tmp.count('{}') == 1 for tmp in context_templates), 'We currently do not support multiple fill-ins for context' # Compute prefixes and suffixes of the tokenized context fill_idxs = [tmp.index('{}') for tmp in context_templates] prefixes, suffixes = [tmp[:fill_idxs[i]] for i, tmp in enumerate(context_templates) ], [tmp[fill_idxs[i] + 2:] for i, tmp in enumerate(context_templates)] lens = [] for prefix, word, suffix in zip(prefixes, words, suffixes): prefix_token = tokenizer.encode(prefix) prefix_word_token = tokenizer.encode(prefix + word) prefix_word_suffix_token = tokenizer.encode(prefix + word + suffix) suffix_len = len(prefix_word_suffix_token) - len(prefix_word_token) # Compute indices of last tokens if subtoken == 'last' or subtoken == 'first_after_last': lens.append([ len(prefix_word_token) - (1 if subtoken == 'last' or suffix_len == 0 else 0) - len(prefix_word_suffix_token) ]) elif subtoken == 'first': lens.append([len(prefix_token) - len(prefix_word_suffix_token)]) else: raise ValueError(f'Unknown subtoken type: {subtoken}') return lens def get_reprs_at_idxs( model: torch.nn.Module, tokenizer: Callable, contexts: List[str], idxs: List[List[int]], layer: int, module_template: str, track: str = 'in', batch_first: bool = True, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Runs input through model and returns averaged representations of the tokens at each index in `idxs`. """ def _batch(n): for i in range(0, len(contexts), n): yield contexts[i:i + n], idxs[i:i + n] assert track in {'in', 'out', 'both'} both = track == 'both' tin, tout = ( (track == 'in' or both), (track == 'out' or both), ) module_name = module_template.format(layer) to_return = {'in': [], 'out': []} def _process(cur_repr, batch_idxs, key): nonlocal to_return cur_repr = cur_repr[0] if isinstance(cur_repr, tuple) else cur_repr if not batch_first: cur_repr = cur_repr.transpose(0, 1) for i, idx_list in enumerate(batch_idxs): to_return[key].append(cur_repr[i][idx_list].mean(0)) for batch_contexts, batch_idxs in _batch(n=512): contexts_tok = tokenizer( batch_contexts, padding=True, return_token_type_ids=False, return_tensors='pt').to(next(model.parameters()).device) with torch.no_grad(): with Trace( module=model, layer=module_name, retain_input=tin, retain_output=tout, ) as tr: model(**contexts_tok) if tin: _process(tr.input, batch_idxs, 'in') if tout: _process(tr.output, batch_idxs, 'out') to_return = {k: torch.stack(v, 0) for k, v in to_return.items() if len(v) > 0} if len(to_return) == 1: return to_return['in'] if tin else to_return['out'] else: return to_return['in'], to_return['out']