Unverified Commit a46180bf authored by Lucia Quirke's avatar Lucia Quirke Committed by GitHub
Browse files

Add support for steering specific attention heads (#3279)

parent 2d7cb5c3
......@@ -71,13 +71,6 @@ class SteeredModel(HFLM):
"""
HFLM with a steered forward pass.
To derive steering vectors from a sparse model loadable with sparsify or sae_lens,
provide the path to a CSV file with the following columns (example rows are provided below):
loader,action,sparse_model,hookpoint,feature_index,steering_coefficient,sae_id,description,
sparsify,add,EleutherAI/sae-pythia-70m-32k,layers.3,30,10.0,,,
sae_lens,add,gemma-scope-2b-pt-res-canonical,layers.20,12082,240.0,layer_20/width_16k/canonical,increase dogs,
To load steering vectors directly, provide the path to a pytorch (.pt) file with content in the following format:
{
......@@ -86,9 +79,17 @@ class SteeredModel(HFLM):
"steering_coefficient": <float>,
"action": <Literal["add", "clamp"]>,
"bias": <torch.Tensor | None>,
"head_index": <int | None>,
},
...
}
To derive steering vectors from a sparse model loadable with sparsify or sae_lens,
provide the path to a CSV file with the following columns (example rows are provided below):
loader,action,sparse_model,hookpoint,feature_index,steering_coefficient,head_index,sae_id,description,
sparsify,add,EleutherAI/sae-pythia-70m-32k,layers.3,30,10.0,,,,
sae_lens,add,gemma-scope-2b-pt-res-canonical,layers.20,12082,240.0,,layer_20/width_16k/canonical,increase dogs,
"""
super().__init__(pretrained=pretrained, device=device, **kwargs)
......@@ -105,27 +106,31 @@ class SteeredModel(HFLM):
hook_to_steer = {}
for hookpoint, steer_info in steer_config.items():
action = steer_info["action"]
steering_coefficient = steer_info["steering_coefficient"]
steering_vector = (
steer_info["steering_vector"].to(self.device).to(self.model.dtype)
)
bias = (
steer_info["bias"].to(self.device).to(self.model.dtype)
if steer_info["bias"] is not None
else None
)
steering_coefficient = float(steer_info.get("steering_coefficient", 1.0))
head_index = steer_info.get("head_index", None)
bias = steer_info.get("bias", None)
if bias is not None:
bias = bias.to(self.device).to(self.model.dtype)
if action == "add":
# Steers the model by adding some multiple of a steering vector to all sequence positions.
hook_to_steer[hookpoint] = (
lambda acts: acts + steering_coefficient * steering_vector
# Steer the model by adding a multiple of a steering vector to all sequence positions.
assert bias is None, "Bias is not supported for the `add` action."
hook_to_steer[hookpoint] = partial(
self.add,
vector=steering_vector * steering_coefficient,
head_index=head_index,
)
elif action == "clamp":
# Steer the model by clamping the activations to a value in the direction of the steering vector.
hook_to_steer[hookpoint] = partial(
self.clamp,
steering_vector=steering_vector,
direction=steering_vector / torch.norm(steering_vector),
value=steering_coefficient,
bias=bias,
head_index=head_index,
)
else:
raise ValueError(f"Unknown hook type: {action}")
......@@ -195,34 +200,62 @@ class SteeredModel(HFLM):
return steer_data
@classmethod
def add(
    cls,
    acts: Tensor,
    vector: Tensor,
    head_index: Optional[int],
):
    """Add a steering vector to the activations, optionally for a single attention head.

    Args:
        acts (Tensor): The activations tensor to edit of shape
            [batch, pos, ..., features]. When ``head_index`` is given, the
            indexing assumes the head dimension is axis 2, i.e. a layout of
            [batch, pos, heads, head_features] — NOTE(review): confirm against
            the hooked module's output shape.
        vector (Tensor): A vector to add of shape [features].
        head_index (int | None): Optional attention head index to add to.

    Returns:
        Tensor: The activations with the vector added. The input tensor is
        never mutated.
    """
    if head_index is None:
        return acts + vector
    # Clone before editing: the previous implementation assigned into the
    # input tensor in place, which mutates activations that may be cached or
    # viewed elsewhere in the model. Cloning keeps both branches
    # allocation-consistent (the no-head path already returns a new tensor).
    out = acts.clone()
    out[:, :, head_index, :] = out[:, :, head_index, :] + vector
    return out
@classmethod
def clamp(
cls,
acts: Tensor,
steering_vector: Tensor,
direction: Tensor,
value: float,
head_index: Optional[int],
bias: Optional[Tensor] = None,
):
"""Clamps a direction of the activations to be the steering vector * the value.
"""Clamps the activations to a given value in a specified direction. The direction
must be a unit vector.
Args:
acts (Tensor): The activations tensor to edit of shape [batch, pos, features]
steering_vector (Tensor): A direction to clamp of shape [features]
acts (Tensor): The activations tensor to edit of shape [batch, pos, ..., features]
direction (Tensor): A direction to clamp of shape [features]
value (float): Value to clamp the direction to
head_index (int | None): Optional attention head index to clamp
bias (Tensor | None): Optional bias to add to the activations
Returns:
Tensor: The modified activations with the specified direction clamped
"""
if bias is not None:
acts = acts - bias
direction = steering_vector / torch.norm(steering_vector)
proj_magnitude = torch.sum(acts * direction, dim=-1, keepdim=True)
orthogonal_component = acts - proj_magnitude * direction
if head_index is not None:
x = acts[:, :, head_index, :]
proj = (x * direction).sum(dim=-1, keepdim=True)
assert proj == acts @ direction
clamped = orthogonal_component + direction * value
clamped = acts.clone()
clamped[:, :, head_index, :] = x + direction * (value - proj)
else:
proj = torch.sum(acts * direction, dim=-1, keepdim=True)
clamped = acts + direction * (value - proj)
if bias is not None:
return clamped + bias
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment