Unverified commit a46180bf authored by Lucia Quirke, committed by GitHub

Add support for steering specific attention heads (#3279)

parent 2d7cb5c3
@@ -71,13 +71,6 @@ class SteeredModel(HFLM):
         """
         HFLM with a steered forward pass.
 
-        To derive steering vectors from a sparse model loadable with sparsify or sae_lens,
-        provide the path to a CSV file with the following columns (example rows are provided below):
-
-        loader,action,sparse_model,hookpoint,feature_index,steering_coefficient,sae_id,description,
-        sparsify,add,EleutherAI/sae-pythia-70m-32k,layers.3,30,10.0,,,
-        sae_lens,add,gemma-scope-2b-pt-res-canonical,layers.20,12082,240.0,layer_20/width_16k/canonical,increase dogs,
-
         To load steering vectors directly, provide the path to a pytorch (.pt) file with content in the following format:
 
         {
@@ -86,9 +79,17 @@ class SteeredModel(HFLM):
                 "steering_coefficient": <float>,
                 "action": <Literal["add", "clamp"]>,
                 "bias": <torch.Tensor | None>,
+                "head_index": <int | None>,
             },
             ...
         }
+
+        To derive steering vectors from a sparse model loadable with sparsify or sae_lens,
+        provide the path to a CSV file with the following columns (example rows are provided below):
+
+        loader,action,sparse_model,hookpoint,feature_index,steering_coefficient,head_index,sae_id,description,
+        sparsify,add,EleutherAI/sae-pythia-70m-32k,layers.3,30,10.0,,,,
+        sae_lens,add,gemma-scope-2b-pt-res-canonical,layers.20,12082,240.0,,layer_20/width_16k/canonical,increase dogs,
         """
         super().__init__(pretrained=pretrained, device=device, **kwargs)
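
For reference, a minimal sketch of producing the `.pt` steering-config file in the format documented above. The hookpoint name (`model.layers.3.self_attn`) and the hidden size are illustrative assumptions; match them to your model's module paths and width.

```python
# Hypothetical example: build and save a steering config in the documented format.
import torch

hidden_size = 512  # assumed model width; use your model's hidden size

steer_config = {
    "model.layers.3.self_attn": {  # assumed hookpoint; any module path works
        "steering_vector": torch.randn(hidden_size),  # direction to steer along
        "steering_coefficient": 10.0,  # scale for "add", target value for "clamp"
        "action": "add",               # one of "add" or "clamp"
        "bias": None,                  # optional; only used by "clamp"
        "head_index": 2,               # steer only attention head 2; None = all
    },
}

torch.save(steer_config, "steer_config.pt")
```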
@@ -105,27 +106,31 @@ class SteeredModel(HFLM):
         hook_to_steer = {}
         for hookpoint, steer_info in steer_config.items():
             action = steer_info["action"]
-            steering_coefficient = steer_info["steering_coefficient"]
             steering_vector = (
                 steer_info["steering_vector"].to(self.device).to(self.model.dtype)
             )
-            bias = (
-                steer_info["bias"].to(self.device).to(self.model.dtype)
-                if steer_info["bias"] is not None
-                else None
-            )
+            steering_coefficient = float(steer_info.get("steering_coefficient", 1.0))
+            head_index = steer_info.get("head_index", None)
+            bias = steer_info.get("bias", None)
+            if bias is not None:
+                bias = bias.to(self.device).to(self.model.dtype)
 
             if action == "add":
-                # Steers the model by adding some multiple of a steering vector to all sequence positions.
-                hook_to_steer[hookpoint] = (
-                    lambda acts: acts + steering_coefficient * steering_vector
-                )
+                # Steer the model by adding a multiple of a steering vector to all sequence positions.
+                assert bias is None, "Bias is not supported for the `add` action."
+                hook_to_steer[hookpoint] = partial(
+                    self.add,
+                    vector=steering_vector * steering_coefficient,
+                    head_index=head_index,
+                )
             elif action == "clamp":
+                # Steer the model by clamping the activations to a value in the direction of the steering vector.
                 hook_to_steer[hookpoint] = partial(
                     self.clamp,
-                    steering_vector=steering_vector,
+                    direction=steering_vector / torch.norm(steering_vector),
                     value=steering_coefficient,
+                    head_index=head_index,
                     bias=bias,
                 )
             else:
                 raise ValueError(f"Unknown hook type: {action}")
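
Beyond passing `head_index`, switching from a loop-local lambda to `functools.partial` also sidesteps Python's late-binding closures: a lambda created inside the loop captures `steering_coefficient` and `steering_vector` by reference, so with several hookpoints configured every hook would end up using the last iteration's values. A standalone illustration of the pitfall (not the harness code itself):

```python
from functools import partial

coeffs = {"layers.0": 1.0, "layers.1": 100.0}

# Late binding: both lambdas share the loop variable `c`, which ends at 100.0.
lambdas = {hp: (lambda x: x * c) for hp, c in coeffs.items()}
print(lambdas["layers.0"](1.0))  # 100.0, not 1.0

# partial binds the current value of `c` at creation time.
partials = {hp: partial(lambda x, c: x * c, c=c) for hp, c in coeffs.items()}
print(partials["layers.0"](1.0))  # 1.0
```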
@@ -195,34 +200,62 @@ class SteeredModel(HFLM):
         return steer_data
 
+    @classmethod
+    def add(
+        cls,
+        acts: Tensor,
+        vector: Tensor,
+        head_index: Optional[int],
+    ):
+        """Adds the given vector to the activations.
+
+        Args:
+            acts (Tensor): The activations tensor to edit of shape [batch, pos, ..., features]
+            vector (Tensor): A vector to add of shape [features]
+            head_index (int | None): Optional attention head index to add to
+        """
+        if head_index is not None:
+            acts[:, :, head_index, :] = acts[:, :, head_index, :] + vector
+        else:
+            acts = acts + vector
+
+        return acts
+
     @classmethod
     def clamp(
         cls,
         acts: Tensor,
-        steering_vector: Tensor,
+        direction: Tensor,
         value: float,
+        head_index: Optional[int],
         bias: Optional[Tensor] = None,
     ):
-        """Clamps a direction of the activations to be the steering vector * the value.
+        """Clamps the activations to a given value in a specified direction. The direction
+        must be a unit vector.
 
         Args:
-            acts (Tensor): The activations tensor to edit of shape [batch, pos, features]
-            steering_vector (Tensor): A direction to clamp of shape [features]
+            acts (Tensor): The activations tensor to edit of shape [batch, pos, ..., features]
+            direction (Tensor): A direction to clamp of shape [features]
             value (float): Value to clamp the direction to
+            head_index (int | None): Optional attention head index to clamp
             bias (Tensor | None): Optional bias to add to the activations
 
         Returns:
             Tensor: The modified activations with the specified direction clamped
         """
         if bias is not None:
             acts = acts - bias
 
-        direction = steering_vector / torch.norm(steering_vector)
-        proj_magnitude = torch.sum(acts * direction, dim=-1, keepdim=True)
-        orthogonal_component = acts - proj_magnitude * direction
-
-        clamped = orthogonal_component + direction * value
+        if head_index is not None:
+            # Clamp only the selected head's activations.
+            x = acts[:, :, head_index, :]
+            proj = (x * direction).sum(dim=-1, keepdim=True)
+            clamped = acts.clone()
+            clamped[:, :, head_index, :] = x + direction * (value - proj)
+        else:
+            proj = torch.sum(acts * direction, dim=-1, keepdim=True)
+            clamped = acts + direction * (value - proj)
 
         if bias is not None:
             return clamped + bias
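
A quick sanity check of the two actions, assuming the hooked tensor is shaped [batch, pos, heads, head_dim] as the head indexing above implies; all names and sizes here are illustrative:

```python
import torch

batch, pos, heads, head_dim = 2, 5, 8, 16
acts = torch.randn(batch, pos, heads, head_dim)

vector = torch.randn(head_dim)
direction = vector / torch.norm(vector)  # unit vector, as clamp requires

# "add" restricted to head 3: every other head is untouched.
added = acts.clone()
added[:, :, 3, :] = added[:, :, 3, :] + vector
assert torch.allclose(added[:, :, 2, :], acts[:, :, 2, :])

# "clamp" restricted to head 3: afterwards the projection of that head's
# activations onto `direction` equals `value` (up to float error).
value = 10.0
x = acts[:, :, 3, :]
proj = (x * direction).sum(dim=-1, keepdim=True)
clamped = acts.clone()
clamped[:, :, 3, :] = x + direction * (value - proj)

new_proj = (clamped[:, :, 3, :] * direction).sum(dim=-1)
assert torch.allclose(new_proj, torch.full_like(new_proj, value), atol=1e-4)
```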