# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from kmeng01/rome.
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
from modelscope import AutoTokenizer

from swift.utils.logger import get_logger

from .nethook import TraceDict, set_requires_grad
from .repr_tools import get_reprs_at_idxs, get_reprs_at_word_tokens, get_words_idxs_in_templates
from .rome_hparams import ROMEHyperParams

logger = get_logger()


def compute_v(model: torch.nn.Module,
              tokenizer: AutoTokenizer,
              request: Dict,
              hparams: ROMEHyperParams,
              layer: int,
              left_vector: torch.Tensor,
              context_templates: List[str],
              batch_first: bool = True) -> torch.Tensor:
    """
    Computes the value (right) vector for the rank-1 update.
    Runs a simple optimization procedure.
    """
    logger.info('Computing right vector (v)')

    # Compile the list of rewriting prompts (which should produce the new
    # target) and KL prompts (generic English/Chinese templates used only to
    # anchor the model's unrelated predictions).
    rewriting_prompts, kl_prompts = [
        context.format(request['prompt']) + request['target'] for context in context_templates
    ], ['{} is a', '{}是一个']
    all_prompts = rewriting_prompts + kl_prompts

    input_tok = tokenizer(
        [prompt.format(request['subject']) for prompt in all_prompts],
        return_tensors='pt',
        padding=True,
        return_token_type_ids=False,
    ).to(model.device)

    # Compute rewriting targets: only the target tokens at the end of each
    # rewriting prompt contribute to the loss; all other positions are -100.
    rewriting_targets = torch.tensor(
        -100, device=model.device).repeat(len(rewriting_prompts), *input_tok['input_ids'].shape[1:])
    prompt = context_templates[0].format(request['prompt'])
    prompt_full = prompt + request['target']
    target_len = len(tokenizer.tokenize(prompt_full)) - len(tokenizer.tokenize(prompt))
    for i in range(len(rewriting_prompts)):
        rewriting_targets[i, -target_len - 1:-1] = input_tok['input_ids'][i, -target_len:].clone()

    # Compute indices of the tokens where the fact is looked up
    lookup_idxs = [
        find_fact_lookup_idx(prompt, request['subject'], tokenizer, hparams.fact_token, verbose=(i == 0))
        for i, prompt in enumerate(all_prompts)
    ]

    # Finalize rewrite and loss layers
    logger.info(f'Rewrite layer is {layer}')

    # Set up an optimization over a latent vector that, when output at the
    # rewrite layer, i.e. the hypothesized fact lookup location, will induce
    # the target token to be predicted at the final layer.
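    # Concretely, the loop below solves
    #   argmin_delta  NLL(target tokens | rewriting prompts)
    #                 + kl_factor * KL(p_edited || p_init) on the KL prompts
    #                 + a weight-decay penalty on ||delta||
    # where delta is added to the rewrite module's output at every lookup
    # index; v* = target_init + delta is the value the rank-1 update must
    # reproduce.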
    # Note: the hasattr check must test 'n_embd' (the attribute actually
    # read) for GPT-2-style configs; other configs expose 'hidden_size'.
    hidden_size = model.config.n_embd if hasattr(model.config, 'n_embd') else model.config.hidden_size
    delta = torch.zeros((hidden_size, ), requires_grad=True, device=model.device)
    target_init, kl_distr_init = None, None

    # Inserts the new "delta" variable at the appropriate part of the computation
    def edit_output_fn(cur_out, cur_layer):
        nonlocal target_init

        # Store the initial value of the vector of interest
        if target_init is None:
            logger.info('Recording initial value of v*')
            # The initial value is recorded for the clean sentence
            target_init = cur_out[0, lookup_idxs[0]].detach().clone()

        # Add delta to the module output at every fact lookup index
        for i, idx in enumerate(lookup_idxs):
            if batch_first:
                cur_out[i, idx, :] += delta
            else:
                cur_out[idx, i, :] += delta

        return cur_out

    # Optimize over delta only; the model itself stays frozen
    opt = torch.optim.Adam([delta], lr=hparams.v_lr)
    set_requires_grad(False, model)

    # Execute optimization
    for it in range(hparams.v_num_grad_steps):
        opt.zero_grad()

        # Forward propagation with the edited MLP output
        with TraceDict(
                module=model,
                layers=[
                    hparams.mlp_module_tmp.format(layer),
                ],
                retain_input=False,
                retain_output=True,
                edit_output=edit_output_fn,
        ) as _:
            logits = model(**input_tok).logits

            # Compute the distribution for the KL divergence term
            kl_logits = torch.stack(
                [logits[i - len(kl_prompts), idx, :] for i, idx in enumerate(lookup_idxs[-len(kl_prompts):])],
                dim=0,
            )
            kl_log_probs = torch.nn.functional.log_softmax(kl_logits, dim=1)
            if kl_distr_init is None:
                kl_distr_init = kl_log_probs.detach().clone()

        # Compute loss on rewriting targets
        log_probs = torch.log_softmax(logits, dim=2)

        loss = torch.gather(
            log_probs,
            2,
            torch.where(rewriting_targets != -100, rewriting_targets, 0).unsqueeze(2),
        ).squeeze(2)
        mask = (rewriting_targets != -100).float()

        # Aggregate total losses
        nll_loss_each = -(loss * mask).sum(1) / target_len
        nll_loss = nll_loss_each.mean()
        kl_loss = hparams.kl_factor * torch.nn.functional.kl_div(
            kl_distr_init, kl_log_probs, log_target=True, reduction='batchmean')
        weight_decay = hparams.v_weight_decay * (torch.norm(delta) / torch.norm(target_init)**2)
        # weight_decay = hparams.v_weight_decay * torch.norm(delta) ** 2
        loss = nll_loss + kl_loss + weight_decay
        logger.info(f'loss {np.round(loss.item(), 3)} = {np.round(nll_loss.item(), 3)} + '
                    f'{np.round(kl_loss.item(), 3)} + {np.round(weight_decay.item(), 3)} '
                    f"avg prob of [{request['target']}] "
                    f'{torch.exp(-nll_loss_each).mean().item()}')
        if loss < 5e-2:
            break

        if it == hparams.v_num_grad_steps - 1:
            break

        # Backpropagate
        loss.backward()
        opt.step()

        # Project delta back into an L2 ball around the origin
        max_norm = hparams.clamp_norm_factor * target_init.norm()
        if delta.norm() > max_norm:
            with torch.no_grad():
                delta[...] = delta * max_norm / delta.norm()

    target = target_init + delta

    # Retrieve cur_input, the current input to the rewrite module (the MLP's
    # second linear layer), and cur_output, its original output.
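    # In the ROME formulation, left_vector is the (inverse-covariance
    # adjusted) subject key produced by compute_u. Dividing the residual
    # (target - cur_output) by dot(cur_input, left_vector) makes the rank-1
    # update built from the outer product of right_vector and left_vector
    # send the subject key k* (cur_input) exactly to v* (target):
    #   W_new @ k* = W @ k* + right_vector * dot(left_vector, k*) = target.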
    cur_input, cur_output = get_module_input_output_at_word(
        model,
        tokenizer,
        layer,
        context_template=request['prompt'],
        word=request['subject'],
        module_template=hparams.rewrite_module_tmp,
        fact_token_strategy=hparams.fact_token,
        batch_first=batch_first)

    # Solve the linear system to compute the right vector
    right_vector = (target - cur_output) / torch.dot(cur_input, left_vector)
    logger.info(f'Delta norm: {(target - cur_output).norm().item()}')
    logger.info(f'Change in target norm: {target_init.norm().item()} to {target.norm().item()} => '
                f'{(target.norm() - target_init.norm()).item()}')
    logger.info(f'Division Factor: {torch.dot(cur_input, left_vector).item()}')
    logger.info(f'Right vector norm: {right_vector.norm()}')

    return right_vector


def get_module_input_output_at_word(model: torch.nn.Module,
                                    tok: Any,
                                    layer: int,
                                    context_template: str,
                                    word: str,
                                    module_template: str,
                                    fact_token_strategy: str,
                                    batch_first: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Retrieves detached representations for a word at the input and output of
    a particular layer module.
    """
    word_repr_args = dict(
        model=model, tokenizer=tok, layer=layer, module_template=module_template, batch_first=batch_first)
    if fact_token_strategy.startswith('subject_'):
        subtoken = fact_token_strategy[len('subject_'):]
        l_input, l_output = get_reprs_at_word_tokens(
            track='both',
            subtoken=subtoken,
            context_templates=[context_template],
            words=[word],
            **word_repr_args,
        )
    elif fact_token_strategy == 'last':
        l_input, l_output = get_reprs_at_idxs(
            track='both',
            contexts=[context_template.format(word)],
            idxs=[[-1]],
            **word_repr_args,
        )
    else:
        raise ValueError(f'fact_token={fact_token_strategy} not recognized')

    l_input, l_output = l_input[0], l_output[0]
    return l_input.detach(), l_output.detach()


def find_fact_lookup_idx(
    prompt: str,
    subject: str,
    tok: Any,
    fact_token_strategy: str,
    verbose: bool = True,
) -> int:
    """
    Computes the hypothesized fact lookup index given a sentence and subject.
    """
    if fact_token_strategy == 'last':
        ret = -1
    elif fact_token_strategy.startswith('subject_'):
        ret = get_words_idxs_in_templates(
            tok,
            context_templates=[prompt],
            words=[subject],
            subtoken=fact_token_strategy[len('subject_'):],
        )[0][0]
    else:
        raise ValueError(f'fact_token={fact_token_strategy} not recognized')

    sentence = prompt.format(subject)
    if verbose:
        logger.info(f'Lookup index found: {ret} | Sentence: {sentence} | Token: '
                    + tok.decode(tok(sentence)['input_ids'][ret]))
    return ret
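

# --- Illustrative sketch (not part of the editing pipeline) ---
# A minimal, self-contained check of the rank-1 identity that `right_vector`
# is constructed to satisfy. Every tensor below is a random stand-in: in the
# real pipeline, `cur_input` (k*) and `cur_output` (W k*) come from
# get_module_input_output_at_word, `left_vector` from compute_u, and `target`
# (v*) from the optimization in compute_v. The helper name is hypothetical.
def _check_rank_one_identity(d: int = 16) -> None:
    torch.manual_seed(0)
    cur_input = torch.randn(d)      # k*: module input at the subject token
    left_vector = torch.randn(d)    # e.g. C^{-1} k*, produced by compute_u
    cur_output = torch.randn(d)     # W k*: original module output
    target = torch.randn(d)         # v*: optimized module output
    right_vector = (target - cur_output) / torch.dot(cur_input, left_vector)
    update = torch.outer(right_vector, left_vector)  # rank-1 weight delta
    # The updated mapping sends k* to v* exactly (up to float error).
    assert torch.allclose(cur_output + update @ cur_input, target, atol=1e-4)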