Unverified Commit b87575d2 authored by Jesus Federico, committed by GitHub
Browse files

feat: add logit_scale to PoolerConfig for affine score calibration (#39435)


Signed-off-by: Jesus Federico <jefp@amazon.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 42c6bb4b
...@@ -267,9 +267,34 @@ You can modify the `problem_type` via problem_type in the Hugging Face config. T ...@@ -267,9 +267,34 @@ You can modify the `problem_type` via problem_type in the Hugging Face config. T
Implement alignment with transformers [ForSequenceClassificationLoss](https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92). Implement alignment with transformers [ForSequenceClassificationLoss](https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92).
### Logit bias ### Affine Score Calibration
You can modify the `logit_bias` (aka `sigmoid_normalize`) through the logit_bias parameter in `vllm.config.PoolerConfig`. Affine score calibration, known as [Platt Scaling](https://en.wikipedia.org/wiki/Platt_scaling) (Platt, 1999) when combined with a sigmoid activation, is one of the most widely used methods for turning raw classifier scores into calibrated probabilities.
The calibration follows the transformation:
`activation(logit_scale * (logit - logit_bias))`
| Parameter | Default | Description |
| --------- | ------- | ----------- |
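| `logit_bias` | `None` | Bias subtracted from logits before activation; no subtraction is applied when `None` |
| `logit_scale` | `None` | Scale factor multiplied into the logits after bias subtraction and before activation; no scaling is applied when `None` |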
Note: `logit_bias` is **subtracted** from the logits (not added), consistent with the `sigmoid_normalize` convention where `sigmoid(x - bias)` centers the sigmoid around the bias value.
The computation order is as follows:
```python
logits -= logit_bias # subtract bias (center scores)
logits *= logit_scale # scale logits
logits = activation(logits) # e.g. sigmoid
```
Example configuration:
```bash
--pooler-config '{"use_activation": true, "logit_bias": 4.5, "logit_scale": 1.0}'
```
## Removed Features ## Removed Features
......
...@@ -83,6 +83,13 @@ class PoolerConfig: ...@@ -83,6 +83,13 @@ class PoolerConfig:
If provided, apply classification logit biases. Defaults to None. If provided, apply classification logit biases. Defaults to None.
""" """
logit_scale: float | None = None
"""
If provided, scale the classification logits by this factor before
activation. Combined with logit_bias, enables affine score calibration:
activation(logit_scale * (score - logit_bias)). Defaults to None.
"""
## for reward models ## for reward models
step_tag_id: int | None = None step_tag_id: int | None = None
""" """
......
...@@ -104,6 +104,7 @@ class ClassifierPoolerHead(SequencePoolerHead): ...@@ -104,6 +104,7 @@ class ClassifierPoolerHead(SequencePoolerHead):
self, self,
classifier: ClassifierFn | None = None, classifier: ClassifierFn | None = None,
logit_bias: float | None = None, logit_bias: float | None = None,
logit_scale: float | None = None,
head_dtype: torch.dtype | str | None = None, head_dtype: torch.dtype | str | None = None,
activation: ActivationFn | None = None, activation: ActivationFn | None = None,
) -> None: ) -> None:
...@@ -111,6 +112,7 @@ class ClassifierPoolerHead(SequencePoolerHead): ...@@ -111,6 +112,7 @@ class ClassifierPoolerHead(SequencePoolerHead):
self.classifier = classifier self.classifier = classifier
self.logit_bias = logit_bias self.logit_bias = logit_bias
self.logit_scale = logit_scale
self.head_dtype = head_dtype self.head_dtype = head_dtype
self.activation = activation self.activation = activation
...@@ -140,6 +142,8 @@ class ClassifierPoolerHead(SequencePoolerHead): ...@@ -140,6 +142,8 @@ class ClassifierPoolerHead(SequencePoolerHead):
# logits shape: [batchsize, num_labels] # logits shape: [batchsize, num_labels]
if self.logit_bias is not None: if self.logit_bias is not None:
logits -= self.logit_bias logits -= self.logit_bias
if self.logit_scale is not None:
logits *= self.logit_scale
if self.activation is not None: if self.activation is not None:
flags = [p.use_activation for p in pooling_params] flags = [p.use_activation for p in pooling_params]
......
...@@ -119,6 +119,7 @@ def pooler_for_classify( ...@@ -119,6 +119,7 @@ def pooler_for_classify(
head_dtype=model_config.head_dtype, head_dtype=model_config.head_dtype,
classifier=classifier, classifier=classifier,
logit_bias=model_config.pooler_config.logit_bias, logit_bias=model_config.pooler_config.logit_bias,
logit_scale=model_config.pooler_config.logit_scale,
activation=resolve_classifier_act_fn( activation=resolve_classifier_act_fn(
model_config, static_num_labels=True, act_fn=act_fn model_config, static_num_labels=True, act_fn=act_fn
), ),
......
...@@ -93,6 +93,7 @@ class TokenClassifierPoolerHead(TokenPoolerHead): ...@@ -93,6 +93,7 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
self, self,
classifier: ClassifierFn | None = None, classifier: ClassifierFn | None = None,
logit_bias: float | None = None, logit_bias: float | None = None,
logit_scale: float | None = None,
head_dtype: torch.dtype | str | None = None, head_dtype: torch.dtype | str | None = None,
activation: ActivationFn | None = None, activation: ActivationFn | None = None,
) -> None: ) -> None:
...@@ -100,6 +101,7 @@ class TokenClassifierPoolerHead(TokenPoolerHead): ...@@ -100,6 +101,7 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
self.classifier = classifier self.classifier = classifier
self.logit_bias = logit_bias self.logit_bias = logit_bias
self.logit_scale = logit_scale
self.head_dtype = head_dtype self.head_dtype = head_dtype
self.activation = activation self.activation = activation
...@@ -127,6 +129,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead): ...@@ -127,6 +129,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
if self.logit_bias is not None: if self.logit_bias is not None:
logits -= self.logit_bias logits -= self.logit_bias
if self.logit_scale is not None:
logits *= self.logit_scale
if self.activation is not None and pooling_param.use_activation: if self.activation is not None and pooling_param.use_activation:
logits = self.activation(logits) logits = self.activation(logits)
......
...@@ -128,6 +128,7 @@ def pooler_for_token_classify( ...@@ -128,6 +128,7 @@ def pooler_for_token_classify(
head_dtype=model_config.head_dtype, head_dtype=model_config.head_dtype,
classifier=classifier, classifier=classifier,
logit_bias=model_config.pooler_config.logit_bias, logit_bias=model_config.pooler_config.logit_bias,
logit_scale=model_config.pooler_config.logit_scale,
activation=resolve_classifier_act_fn( activation=resolve_classifier_act_fn(
model_config, static_num_labels=False, act_fn=act_fn model_config, static_num_labels=False, act_fn=act_fn
), ),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment