Unverified Commit 7be00f3b authored by Stella Biderman, committed by GitHub

Sae steered (#2750)



* Can evaluate sae_steered_beta models

* Can evaluate sae_steered_beta models add

* Ran pre-commit and pytest

* Add larger description for feature in dog_steer.csv

* Cache the repeated SAE instead of making a new one every time.

* Types of hook in InterventionModel

* Other steer functions. Deprecated InterventionModel instantiation from hooks list; now only csv

* Other hook functions

---------
Co-authored-by: AMindToThink <cs29824@sting-vm-1.cs.northwestern.edu>
Co-authored-by: Matthew Khoriaty <61801493+AMindToThink@users.noreply.github.com>
parent a87fe425
latent_idx,steering_coefficient,sae_release,sae_id,description
12082,240.0,gemma-scope-2b-pt-res-canonical,layer_20/width_16k/canonical,this feature has been found on Neuronpedia to make the model talk about dogs and obedience
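For reference, a minimal sketch of how this CSV is consumed (the relative path and task choice are illustrative; the recorded run at the bottom of this commit used equivalent model_args):

import lm_eval

results = lm_eval.simple_evaluate(
    model="sae_steered_beta",  # registered below via @register_model
    model_args="base_name=google/gemma-2-2b,csv_path=examples/dog_steer.csv",
    tasks=["mmlu_abstract_algebra"],
)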
@@ -13,6 +13,7 @@ from . import (
openai_completions,
optimum_ipex,
optimum_lm,
sae_steered_beta,
sglang_causallms,
textsynth,
vllm_causallms,
......
import einops
from torch import Tensor
from sae_lens import SAE
from transformer_lens.hook_points import HookPoint

# Andre was working on Matthew's folders, and Matthew didn't want to edit the same doc at the same time.
def steering_hook_projection(
    activations,  # Float[Tensor, "batch pos d_in"]; either jaxtyping or lm-evaluation-harness' pre-commit script rejects a full type hint here.
hook: HookPoint,
sae: SAE,
latent_idx: int,
steering_coefficient: float,
) -> Tensor:
"""
Steers the model by finding the projection of each activations,
along the specified feature and adding some multiple of that projection to the activation.
"""
bad_feature = sae.W_dec[latent_idx] # batch, pos, d_in @ d_in, d_embedding -> batch, pos, d_embedding
dot_products = einops.einsum(activations, bad_feature, "batch pos d_embedding, d_embedding -> batch pos")
    # Dividing by the norm once (not squared) leaves the true projection scaled by
    # ||bad_feature||; the steering_coefficient absorbs that constant factor.
    dot_products /= bad_feature.norm()
# Calculate the projection of activations onto the feature direction
projection = einops.einsum(
dot_products,
bad_feature,
"batch pos, d_embedding -> batch pos d_embedding"
)
# Add scaled projection to original activations
return activations + steering_coefficient * projection
\ No newline at end of file
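A quick sanity check of the projection hook above (a minimal sketch: it assumes steering_hook_projection is importable, and types.SimpleNamespace stands in for the SAE since only W_dec is used):

import types

import einops
import torch

batch, pos, d = 2, 5, 16
activations = torch.randn(batch, pos, d)
fake_sae = types.SimpleNamespace(W_dec=torch.randn(32, d))  # toy stand-in with 32 latents

steered = steering_hook_projection(
    activations, hook=None, sae=fake_sae, latent_idx=7, steering_coefficient=3.0
)

# The change should lie entirely along W_dec[7]: removing that component leaves ~zero.
delta = steered - activations
direction = fake_sae.W_dec[7] / fake_sae.W_dec[7].norm()
along = einops.einsum(delta, direction, "b p d, d -> b p").unsqueeze(-1) * direction
assert torch.allclose(delta - along, torch.zeros_like(delta), atol=1e-4)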
import torch
def batch_vector_projection(vectors: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Projects each vector in a batch onto a target vector.
Args:
vectors: Tensor of shape (b, p, d) where:
b is the batch size
p is the number of vectors per batch
d is the dimension of each vector
target: Tensor of shape (d,) - the vector to project onto
    Returns:
        A tuple (projections, dot_products): projections has shape (b, p, d),
        dot_products has shape (b, p, 1)
Example:
b, p, d = 32, 10, 3 # batch of 32, 10 vectors each, in 3D
vectors = torch.randn(b, p, d)
target = torch.randn(d)
projections = batch_vector_projection(vectors, target)
"""
# Ensure target is unit vector
target = torch.nn.functional.normalize(target, dim=0)
# Reshape target to (1, 1, d) for broadcasting
target_reshaped = target.view(1, 1, -1)
# Compute dot product between each vector and target
# Result shape: (b, p, 1)
dot_products = torch.sum(vectors * target_reshaped, dim=-1, keepdim=True)
# Project each vector onto target
# Multiply dot products by target vector
# Result shape: (b, p, d)
projections = dot_products * target_reshaped
return projections, dot_products
# Test function
if __name__ == "__main__":
# Create sample data
batch_size, vectors_per_batch, dim = 2, 3, 4
vectors = torch.randn(batch_size, vectors_per_batch, dim)
target = torch.randn(dim)
# Compute projections
projected, dot_products = batch_vector_projection(vectors, target)
_, zero_dot_products = batch_vector_projection(vectors - projected, target)
assert torch.allclose(zero_dot_products, torch.zeros_like(zero_dot_products), atol=1e-6)
print("Without proj, close to zero")
# Verify shapes
print(f"Input shape: {vectors.shape}")
print(f"Target shape: {target.shape}")
print(f"Output shape: {projected.shape}")
\ No newline at end of file
"""Credit: contributed by https://github.com/AMindToThink aka Matthew Khoriaty of Northwestern University."""
from functools import partial
import torch
from jaxtyping import Float
from sae_lens import SAE, HookedSAETransformer
from torch import Tensor
from transformer_lens import loading_from_pretrained
from transformer_lens.hook_points import HookPoint
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
def steering_hook_add_scaled_one_hot(
    activations,  # Float[Tensor, "batch pos d_in"]; either jaxtyping or lm-evaluation-harness' pre-commit script rejects a full type hint here.
hook: HookPoint,
sae: SAE,
latent_idx: int,
steering_coefficient: float,
) -> Tensor:
"""
Steers the model by returning a modified activations tensor, with some multiple of the steering vector added to all
sequence positions.
"""
return activations + steering_coefficient * sae.W_dec[latent_idx]
# def steering_hook_clamp(
#     activations,  # Float[Tensor, "batch pos d_in"]; see the note on type hints above.
#     hook: HookPoint,
#     sae: SAE,
#     latent_idx: int,
#     steering_coefficient: float,
# ) -> Tensor:
#     """
#     Steers the model by encoding the activations with the SAE, clamping the chosen
#     latent to the steering coefficient, and decoding back into the residual stream.
#     """
#     raise NotImplementedError
#     z = sae.encode(activations)
#     z[:, :, latent_idx] = steering_coefficient
#     return sae.decode(z)
def clamp_sae_feature(sae_acts: Tensor, hook: HookPoint, latent_idx: int, value: float) -> Tensor:
"""Clamps a specific latent feature in the SAE activations to a fixed value.
Args:
sae_acts (Tensor): The SAE activations tensor, shape [batch, pos, features]
hook (HookPoint): The transformer-lens hook point
latent_idx (int): Index of the latent feature to clamp
value (float): Value to clamp the feature to
Returns:
Tensor: The modified SAE activations with the specified feature clamped
"""
sae_acts[:, :, latent_idx] = value
return sae_acts
def clamp_original(sae_acts: Tensor, hook: HookPoint, latent_idx: int, value: float) -> Tensor:
    """Clamps a specific latent feature to a fixed value, but only at positions where
    the feature is already active (> 0); inactive positions are left untouched.
    Args:
        sae_acts (Tensor): The SAE activations tensor, shape [batch, pos, features]
        hook (HookPoint): The transformer-lens hook point
        latent_idx (int): Index of the latent feature to clamp
        value (float): Value to clamp the active entries of the feature to
    Returns:
        Tensor: The modified SAE activations
    """
mask = sae_acts[:, :, latent_idx] > 0 # Create a boolean mask where values are greater than 0
sae_acts[:, :, latent_idx][mask] = value # Replace values conditionally
return sae_acts
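# Illustration (not executed): how the two clamp hooks differ on one latent column.
# If sae_acts[:, :, latent_idx] is [0.0, 1.3, 0.0, 2.7] across four positions and value is 5.0:
#   clamp_sae_feature -> [5.0, 5.0, 5.0, 5.0]   (unconditional: every position is set)
#   clamp_original    -> [0.0, 5.0, 0.0, 5.0]   (only positions where the feature already fired)
# from_csv below wires hook_action == "clamp" to clamp_original.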
def print_sae_acts(sae_acts: Tensor, hook: HookPoint) -> Tensor:
    """Debugging hook: prints the hook name, the shape of the SAE latent activations,
    and whether every activation is positive, then returns the activations unchanged.
    Args:
        sae_acts (Tensor): The SAE activations tensor, shape [batch, pos, features]
        hook (HookPoint): The transformer-lens hook point
    Returns:
        Tensor: The unmodified SAE activations
    """
print(40*"----")
print(f"This is the latent activations of {hook.name}")
print(sae_acts.shape)
print(torch.all(sae_acts > 0))
return sae_acts
# NOTE: from_csv below wires "clamp" to clamp_original, so this mapping is currently unused.
string_to_steering_function_dict: dict = {"add": steering_hook_add_scaled_one_hot, "clamp": clamp_sae_feature}
class InterventionModel(HookedSAETransformer):
def __init__(self, base_name: str, device: str = "cuda:0", model=None):
trueconfig = loading_from_pretrained.get_pretrained_model_config(
base_name, device=device
)
super().__init__(trueconfig)
self.model = model or HookedSAETransformer.from_pretrained(base_name, device=device)
self.model.use_error_term = True
self.model.eval()
self.device = device # Add device attribute
self.to(device) # Ensure model is on the correct device
@classmethod
def from_csv(
cls, csv_path: str, base_name: str, device: str = "cuda:0"
) -> "InterventionModel":
"""
Create an InterventionModel from a CSV file containing steering configurations.
Expected CSV format:
        latent_idx,steering_coefficient,sae_release,sae_id,description
        12082,240.0,gemma-scope-2b-pt-res-canonical,layer_20/width_16k/canonical,increase dogs
        An optional hook_action column ("add", "clamp", or "print") selects the steering function; it defaults to "add".
...
        Args:
            csv_path: Path to the CSV file containing steering configurations
            base_name: Name of the pretrained base model to load
            device: Device to place the model on
Returns:
InterventionModel with configured steering hooks
"""
import pandas as pd
model = HookedSAETransformer.from_pretrained(base_name, device=device)
# Read steering configurations
df = pd.read_csv(csv_path)
# Create hooks for each row in the CSV
sae_cache = {}
hooks = []
def get_sae(sae_release, sae_id):
cache_key = (sae_release, sae_id)
if cache_key not in sae_cache:
sae_cache[cache_key] = SAE.from_pretrained(
sae_release, sae_id, device=str(device)
)[0]
return sae_cache[cache_key]
for _, row in df.iterrows():
sae_release = row["sae_release"]
sae_id = row["sae_id"]
latent_idx = int(row["latent_idx"])
steering_coefficient = float(row["steering_coefficient"])
sae = get_sae(sae_release=sae_release, sae_id=sae_id)
sae.use_error_term = True
sae.eval()
model.add_sae(sae)
hook_action = row.get("hook_action", "add")
if hook_action == "add":
hook_name = f"{sae.cfg.hook_name}.hook_sae_input" # we aren't actually putting the input through the model
hook = partial(steering_hook_add_scaled_one_hot,
sae=sae,
latent_idx=latent_idx,
steering_coefficient=steering_coefficient,
)
model.add_hook(hook_name, hook)
elif hook_action == "clamp":
sae.add_hook("hook_sae_acts_post", partial(clamp_original, latent_idx=latent_idx, value=steering_coefficient))
            elif hook_action == "print":
                sae.add_hook("hook_sae_acts_post", print_sae_acts)
else:
raise ValueError(f"Unknown hook type: {hook_action}")
# Create and return the model
return cls(base_name=base_name, device=device, model=model)
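    # Example usage (illustrative paths and device; not executed here):
    #   model = InterventionModel.from_csv(
    #       csv_path="examples/dog_steer.csv", base_name="google/gemma-2-2b", device="cuda:0"
    #   )
    #   logits = model(input_ids)  # steering hooks fire on every forward pass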
def forward(self, *args, **kwargs):
# Handle both input_ids and direct tensor inputs
if "input_ids" in kwargs:
input_tensor = kwargs.pop("input_ids") # Use pop to remove it
elif args:
input_tensor = args[0]
args = args[1:] # Remove the first argument
else:
input_tensor = None
        with torch.no_grad():  # Unclear why no_grad is necessary even with everything in eval mode, but without it we hit CUDA out-of-memory errors.
output = self.model.forward(input_tensor, *args, **kwargs)
return output
@register_model("sae_steered_beta")
class InterventionModelLM(HFLM):
def __init__(self, base_name, csv_path, **kwargs):
self.swap_in_model = InterventionModel.from_csv(
csv_path=csv_path, base_name=base_name, device=kwargs.get("device", "cuda")
)
self.swap_in_model.eval()
# Initialize other necessary attributes
super().__init__(pretrained=base_name, **kwargs)
if hasattr(self, "_model"):
# Delete all the model's parameters but keep the object
for param in self._model.parameters():
param.data.zero_()
param.requires_grad = False
# Remove all model modules while keeping the base object
for name, module in list(self._model.named_children()):
delattr(self._model, name)
torch.cuda.empty_cache()
def _model_call(self, inputs):
return self.swap_in_model.forward(inputs)
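# A minimal smoke test of the wrapper (a sketch, not part of the harness: it assumes a
# CUDA device, network access to download google/gemma-2-2b, and the example CSV path).
if __name__ == "__main__":
    lm = InterventionModelLM(
        base_name="google/gemma-2-2b",
        csv_path="examples/dog_steer.csv",
        device="cuda:0",
    )
    tokens = torch.tensor([[2, 651, 5929]], device="cuda:0")  # arbitrary token ids; 2 is gemma's <bos>
    logits = lm._model_call(tokens)
    print(logits.shape)  # expect [1, seq_len, vocab_size]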
@@ -23,6 +23,7 @@ dependencies = [
"evaluate",
"datasets>=2.16.0",
"evaluate>=0.4.0",
"jaxtyping",
"jsonlines",
"numexpr",
"peft>=0.2.0",
@@ -30,11 +31,13 @@ dependencies = [
"pytablewriter",
"rouge-score>=0.0.4",
"sacrebleu>=1.5.0",
"sae_lens",
"scikit-learn>=0.24.1",
"sqlitedict",
"torch>=1.8",
"tqdm-multiprocess",
"transformers>=4.1",
"transformer-lens>=2.7.0",
"zstandard",
"dill",
"word2number",
......
{
"results": {
"mmlu_abstract_algebra": {
"alias": "abstract_algebra",
"acc,none": 0.34,
"acc_stderr,none": 0.047609522856952344
}
},
"group_subtasks": {
"mmlu_abstract_algebra": []
},
"configs": {
"mmlu_abstract_algebra": {
"task": "mmlu_abstract_algebra",
"task_alias": "abstract_algebra",
"tag": "mmlu_stem_tasks",
"dataset_path": "hails/mmlu_no_train",
"dataset_name": "abstract_algebra",
"dataset_kwargs": {
"trust_remote_code": true
},
"test_split": "test",
"fewshot_split": "dev",
"doc_to_text": "{{question.strip()}}\nA. {{choices[0]}}\nB. {{choices[1]}}\nC. {{choices[2]}}\nD. {{choices[3]}}\nAnswer:",
"doc_to_target": "answer",
"unsafe_code": false,
"doc_to_choice": [
"A",
"B",
"C",
"D"
],
"description": "The following are multiple choice questions (with answers) about abstract algebra.\n\n",
"target_delimiter": " ",
"fewshot_delimiter": "\n\n",
"fewshot_config": {
"sampler": "first_n"
},
"num_fewshot": 0,
"metric_list": [
{
"metric": "acc",
"aggregation": "mean",
"higher_is_better": true
}
],
"output_type": "multiple_choice",
"repeats": 1,
"should_decontaminate": false,
"metadata": {
"version": 1.0
}
}
},
"versions": {
"mmlu_abstract_algebra": 1.0
},
"n-shot": {
"mmlu_abstract_algebra": 0
},
"higher_is_better": {
"mmlu_abstract_algebra": {
"acc": true
}
},
"n-samples": {
"mmlu_abstract_algebra": {
"original": 100,
"effective": 100
}
},
"config": {
"model": "sae_steered_beta",
"model_args": "base_name=google/gemma-2-2b,csv_path=/home/cs29824/matthew/lm-evaluation-harness/examples/dog_steer.csv",
"model_num_parameters": 0,
"model_dtype": null,
"model_revision": "main",
"model_sha": "c5ebcd40d208330abc697524c919956e692655cf",
"batch_size": "auto",
"batch_sizes": [
16
],
"device": "cuda:0",
"use_cache": null,
"limit": null,
"bootstrap_iters": 100000,
"gen_kwargs": null,
"random_seed": 0,
"numpy_seed": 1234,
"torch_seed": 1234,
"fewshot_seed": 1234
},
"git_hash": "e16afa2f",
"date": 1737419939.4888458,
"pretty_env_info": "PyTorch version: 2.5.1+cu124\nIs debug build: False\nCUDA used to build PyTorch: 12.4\nROCM used to build PyTorch: N/A\n\nOS: Ubuntu 20.04.6 LTS (x86_64)\nGCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0\nClang version: Could not collect\nCMake version: version 3.16.3\nLibc version: glibc-2.31\n\nPython version: 3.11.11 | packaged by conda-forge | (main, Dec 5 2024, 14:17:24) [GCC 13.3.0] (64-bit runtime)\nPython platform: Linux-5.4.0-1125-kvm-x86_64-with-glibc2.31\nIs CUDA available: True\nCUDA runtime version: 10.1.243\nCUDA_MODULE_LOADING set to: LAZY\nGPU models and configuration: \nGPU 0: Quadro RTX 8000\nGPU 1: Quadro RTX 8000\n\nNvidia driver version: 545.23.08\ncuDNN version: Probably one of the following:\n/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn.so.8.4.1\n/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.4.1\n/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.4.1\n/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.4.1\n/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.4.1\n/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.4.1\n/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.4.1\nHIP runtime version: N/A\nMIOpen runtime version: N/A\nIs XNNPACK available: True\n\nCPU:\nArchitecture: x86_64\nCPU op-mode(s): 32-bit, 64-bit\nByte Order: Little Endian\nAddress sizes: 46 bits physical, 48 bits virtual\nCPU(s): 16\nOn-line CPU(s) list: 0-15\nThread(s) per core: 1\nCore(s) per socket: 1\nSocket(s): 16\nNUMA node(s): 1\nVendor ID: GenuineIntel\nCPU family: 6\nModel: 85\nModel name: Intel Xeon Processor (Cascadelake)\nStepping: 6\nCPU MHz: 2294.608\nBogoMIPS: 4589.21\nVirtualization: VT-x\nHypervisor vendor: KVM\nVirtualization type: full\nL1d cache: 512 KiB\nL1i cache: 512 KiB\nL2 cache: 64 MiB\nL3 cache: 256 MiB\nNUMA node0 CPU(s): 0-15\nVulnerability Gather data sampling: Unknown: Dependent on hypervisor status\nVulnerability Itlb multihit: Not affected\nVulnerability L1tf: Not affected\nVulnerability Mds: Not affected\nVulnerability Meltdown: Not affected\nVulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown\nVulnerability Retbleed: Mitigation; Enhanced IBRS\nVulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp\nVulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization\nVulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI Vulnerable, KVM SW loop\nVulnerability Srbds: Not affected\nVulnerability Tsx async abort: Mitigation; TSX disabled\nFlags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat umip pku avx512_vnni md_clear arch_capabilities\n\nVersions of relevant libraries:\n[pip3] mypy==1.14.1\n[pip3] mypy-extensions==1.0.0\n[pip3] numpy==1.26.4\n[pip3] 
torch==2.5.1\n[pip3] triton==3.1.0\n[conda] numpy 1.26.4 pypi_0 pypi\n[conda] torch 2.5.1 pypi_0 pypi\n[conda] triton 3.1.0 pypi_0 pypi",
"transformers_version": "4.48.1",
"upper_git_hash": null,
"tokenizer_pad_token": [
"<pad>",
"0"
],
"tokenizer_eos_token": [
"<eos>",
"1"
],
"tokenizer_bos_token": [
"<bos>",
"2"
],
"eot_token_id": 1,
"max_length": 8192,
"task_hashes": {},
"model_source": "sae_steered_beta",
"model_name": "/home/cs29824/matthew/lm-evaluation-harness/examples/dog_steer.csv",
"model_name_sanitized": "__home__cs29824__matthew__lm-evaluation-harness__examples__dog_steer.csv",
"system_instruction": null,
"system_instruction_sha": null,
"fewshot_as_multiturn": false,
"chat_template": null,
"chat_template_sha": null,
"start_time": 2970008.635285475,
"end_time": 2970078.697630497,
"total_evaluation_time_seconds": "70.06234502233565"
}