sae_steered_beta.py
"""Credit: contributed by https://github.com/AMindToThink aka Matthew Khoriaty of Northwestern University."""

from functools import partial

import torch
from sae_lens import SAE, HookedSAETransformer
from torch import Tensor
from transformer_lens import loading_from_pretrained
from transformer_lens.hook_points import HookPoint

from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM


def steering_hook_add_scaled_one_hot(
    activations,  # Float[Tensor, "batch pos d_in"]; either jaxtyping or the lm-evaluation-harness pre-commit hooks reject this annotation, so it is left as a comment.
    hook: HookPoint,
    sae: SAE,
    latent_idx: int,
    steering_coefficient: float,
) -> Tensor:
    """
    Steers the model by returning a modified activations tensor, with some multiple of the steering vector added to all
    sequence positions.
    """
    return activations + steering_coefficient * sae.W_dec[latent_idx]
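
# A minimal sketch (hypothetical SAE release, id, and coefficient) of wiring this hook
# up by hand, mirroring what InterventionModel.from_csv does below; `model` is assumed
# to be a HookedSAETransformer:
#
#     sae = SAE.from_pretrained(
#         "gemma-scope-2b-pt-res-canonical", "layer_20/width_16k/canonical"
#     )[0]
#     model.add_sae(sae)
#     model.add_hook(
#         f"{sae.cfg.hook_name}.hook_sae_input",
#         partial(
#             steering_hook_add_scaled_one_hot,
#             sae=sae,
#             latent_idx=12082,
#             steering_coefficient=240.0,
#         ),
#     )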

# def steering_hook_clamp(
#     activations,  # Float[Tensor, "batch pos d_in"]; see the note on steering_hook_add_scaled_one_hot.
#     hook: HookPoint,
#     sae: SAE,
#     latent_idx: int,
#     steering_coefficient: float,
# ) -> Tensor:
#     """
#     Steers the model by encoding the activations with the SAE, setting one latent to the
#     steering coefficient, and decoding back. Currently unused; clamping is done with
#     clamp_original on the SAE's hook_sae_acts_post instead.
#     """
#     z = sae.encode(activations)
#     z[..., latent_idx] = steering_coefficient
#     return sae.decode(z)


def clamp_sae_feature(sae_acts: Tensor, hook: HookPoint, latent_idx: int, value: float) -> Tensor:
    """Sets a specific latent feature in the SAE activations to a fixed value at every position.

    Args:
        sae_acts (Tensor): The SAE activations tensor, shape [batch, pos, features]
        hook (HookPoint): The transformer-lens hook point
        latent_idx (int): Index of the latent feature to clamp
        value (float): Value to clamp the feature to

    Returns:
        Tensor: The modified SAE activations with the specified feature clamped
    """
    sae_acts[:, :, latent_idx] = value
    return sae_acts

def clamp_original(sae_acts: Tensor, hook: HookPoint, latent_idx: int, value: float) -> Tensor:
    """Clamps a specific latent feature in the SAE activations to a fixed value,
    but only at positions where the feature is already active (> 0).

    Args:
        sae_acts (Tensor): The SAE activations tensor, shape [batch, pos, features]
        hook (HookPoint): The transformer-lens hook point
        latent_idx (int): Index of the latent feature to clamp
        value (float): Value to clamp the feature to

    Returns:
        Tensor: The modified SAE activations with the specified feature clamped
    """
    mask = sae_acts[:, :, latent_idx] > 0  # boolean mask of positions where the feature fires
    sae_acts[:, :, latent_idx][mask] = value  # overwrite only those positions
    return sae_acts
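
# A toy illustration (not model code) of how the two clamp hooks differ:
# clamp_sae_feature overwrites the feature everywhere, while clamp_original
# only overwrites positions where the feature already fired.
#
#     acts = torch.tensor([[[0.0], [2.5]]])          # [batch=1, pos=2, features=1]
#     clamp_sae_feature(acts.clone(), None, 0, 5.0)  # -> [[[5.0], [5.0]]]
#     clamp_original(acts.clone(), None, 0, 5.0)     # -> [[[0.0], [5.0]]]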

def print_sae_acts(sae_acts: Tensor, hook: HookPoint) -> Tensor:
    """Debugging hook: prints the hook name, the shape of the SAE latent activations,
    and whether every activation is positive, then returns the activations unchanged.

    Args:
        sae_acts (Tensor): The SAE activations tensor, shape [batch, pos, features]
        hook (HookPoint): The transformer-lens hook point

    Returns:
        Tensor: The unmodified SAE activations
    """
    print(40 * "----")
    print(f"These are the latent activations of {hook.name}")
    print(sae_acts.shape)
    print(torch.all(sae_acts > 0))
    return sae_acts

string_to_steering_function_dict: dict = {
    "add": steering_hook_add_scaled_one_hot,
    "clamp": clamp_sae_feature,
}  # NOTE: currently unused; from_csv dispatches on hook_action explicitly.

class InterventionModel(HookedSAETransformer):
    def __init__(self, base_name: str, device: str = "cuda:0", model=None):
        trueconfig = loading_from_pretrained.get_pretrained_model_config(
            base_name, device=device
        )
        super().__init__(trueconfig)
        self.model = model or HookedSAETransformer.from_pretrained(base_name, device=device)
        self.model.use_error_term = True
        self.model.eval()
        self.device = device  # Add device attribute
        self.to(device)  # Ensure model is on the correct device

    @classmethod
    def from_csv(
        cls, csv_path: str, base_name: str, device: str = "cuda:0"
    ) -> "InterventionModel":
        """
        Create an InterventionModel from a CSV file containing steering configurations.

        Expected CSV format (hook_action is optional and defaults to "add"):
        latent_idx,steering_coefficient,sae_release,sae_id,hook_action,description
        12082,240.0,gemma-scope-2b-pt-res-canonical,layer_20/width_16k/canonical,add,increase dogs
        ...

        Args:
            csv_path: Path to the CSV file containing steering configurations
            base_name: Name of the pretrained base model to load
            device: Device to place the model on

        Returns:
            InterventionModel with configured steering hooks
        """
        import pandas as pd
        model = HookedSAETransformer.from_pretrained(base_name, device=device)
        # Read steering configurations
        df = pd.read_csv(csv_path)
        # Create hooks for each row in the CSV
        sae_cache = {}
        hooks = []

        def get_sae(sae_release, sae_id):
            cache_key = (sae_release, sae_id)
            if cache_key not in sae_cache:
                sae_cache[cache_key] = SAE.from_pretrained(
                    sae_release, sae_id, device=str(device)
                )[0]
            return sae_cache[cache_key]

        for _, row in df.iterrows():
            sae_release = row["sae_release"]
            sae_id = row["sae_id"]
            latent_idx = int(row["latent_idx"])
            steering_coefficient = float(row["steering_coefficient"])
            sae = get_sae(sae_release=sae_release, sae_id=sae_id)
            sae.use_error_term = True
            sae.eval()
            model.add_sae(sae)
            hook_action = row.get("hook_action", "add")
            if hook_action == "add":
                # Steer at the SAE's input hook: with use_error_term=True the SAE
                # passes its input through unchanged, so adding the vector here is
                # equivalent to adding it directly to the residual stream.
                hook_name = f"{sae.cfg.hook_name}.hook_sae_input"
                hook = partial(
                    steering_hook_add_scaled_one_hot,
                    sae=sae,
                    latent_idx=latent_idx,
                    steering_coefficient=steering_coefficient,
                )
                model.add_hook(hook_name, hook)
            elif hook_action == "clamp":
                sae.add_hook(
                    "hook_sae_acts_post",
                    partial(clamp_original, latent_idx=latent_idx, value=steering_coefficient),
                )
            elif hook_action == "print":
                sae.add_hook("hook_sae_acts_post", print_sae_acts)
            else:
                raise ValueError(f"Unknown hook action: {hook_action}")

        # Create and return the model
        return cls(base_name=base_name, device=device, model=model)
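
    # A minimal usage sketch (hypothetical CSV path and base model):
    #
    #     model = InterventionModel.from_csv(
    #         csv_path="steering.csv", base_name="google/gemma-2-2b", device="cuda:0"
    #     )
    #     logits = model(tokens)  # forward() below delegates to the hooked base model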

    def forward(self, *args, **kwargs):
        # Handle both input_ids and direct tensor inputs
        if "input_ids" in kwargs:
            input_tensor = kwargs.pop("input_ids")  # Use pop to remove it
        elif args:
            input_tensor = args[0]
            args = args[1:]  # Remove the first argument
        else:
            input_tensor = None
        # no_grad is required even though the model is in eval mode: eval() does not
        # disable autograd, so without it activations are retained for backward and
        # CUDA runs out of memory.
        with torch.no_grad():
            output = self.model.forward(input_tensor, *args, **kwargs)
        return output


@register_model("sae_steered_beta")
class InterventionModelLM(HFLM):
    def __init__(self, base_name, csv_path, **kwargs):
        self.swap_in_model = InterventionModel.from_csv(
            csv_path=csv_path, base_name=base_name, device=kwargs.get("device", "cuda")
        )
        self.swap_in_model.eval()
        # Initialize other necessary attributes
        super().__init__(pretrained=base_name, **kwargs)
        if hasattr(self, "_model"):
            # HFLM loaded its own copy of the base model, but every forward call is
            # routed to swap_in_model (see _model_call), so reclaim the duplicate's
            # memory below.
            # Delete all the model's parameters but keep the object
            for param in self._model.parameters():
                param.data.zero_()
                param.requires_grad = False
            # Remove all model modules while keeping the base object
            for name, module in list(self._model.named_children()):
                delattr(self._model, name)
            torch.cuda.empty_cache()

    def _model_call(self, inputs):
        return self.swap_in_model.forward(inputs)
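
# A sketch of invoking this model through the lm-evaluation-harness CLI
# (hypothetical base model, CSV path, and task):
#
#     lm_eval --model sae_steered_beta \
#         --model_args base_name=google/gemma-2-2b,csv_path=steering.csv \
#         --tasks lambada_openai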