"vscode:/vscode.git/clone" did not exist on "288a938872cc3c6150a486aaa15a3b5dcadf42cc"
Unverified commit fa6ae311, authored by Jesus Federico and committed by GitHub
Browse files

feat: rename logit_bias/logit_scale to logit_mean/logit_sigma for affine score calibration (#39530)


Signed-off-by: Jesus Federico <jefp@amazon.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
parent 2a3c32ce
......@@ -273,27 +273,25 @@ Affine Score Calibration, also known as [Platt Scaling](https://en.wikipedia.org
The calibration follows the transformation:
`activation(logit_scale * (logit - logit_bias))`
`activation((logit - logit_mean) / logit_sigma)`
| Parameter | Default | Description |
| --------- | ------- | ----------- |
| `logit_bias` | `None` | Bias subtracted from logits before activation |
| `logit_scale` | `None` | Scale factor applied to logits after bias subtraction |
Note: `logit_bias` is **subtracted** from the logits (not added), consistent with the `sigmoid_normalize` convention where `sigmoid(x - bias)` centers the sigmoid around the bias value.
| `logit_mean` | `None` | Mean subtracted from logits (centers scores) |
| `logit_sigma` | `None` | Standard deviation that the logits are divided by after mean subtraction (must be non-zero) |
The computation order is as follows:
```python
logits -= logit_bias # subtract bias (center scores)
logits *= logit_scale # scale logits
logits -= logit_mean # subtract mean (center scores)
logits /= logit_sigma # divide by sigma (scale)
logits = activation(logits) # e.g. sigmoid
```
Example configuration:
```bash
--pooler-config '{"use_activation": true, "logit_bias": 4.5, "logit_scale": 1.0}'
--pooler-config '{"use_activation": true, "logit_mean": 4.5, "logit_sigma": 1.0}'
```
## Removed Features
......@@ -301,3 +299,7 @@ Example configuration:
### Remove softmax from PoolingParams
We have already removed `softmax` and `activation` from PoolingParams. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.
### Remove `logit_bias` and `logit_scale`
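`logit_bias` and `logit_scale` are deprecated in favor of `logit_mean` and `logit_sigma` respectively. `logit_bias` is a direct alias for `logit_mean`. `logit_scale` is the reciprocal of `logit_sigma`, so a supplied `logit_scale` is automatically converted to `logit_sigma = 1/logit_scale` (for example, `logit_scale = 2.0` becomes `logit_sigma = 0.5`). Setting a deprecated parameter together with its replacement raises a `ValueError`, as does a `logit_scale` or `logit_sigma` of 0. These deprecated parameters will be removed in v0.21.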
......@@ -77,17 +77,31 @@ class PoolerConfig:
Defaults to None (i.e. set to max_model_len).
"""
## for classification models
## for classification models — affine score calibration
logit_mean: float | None = None
"""
If provided, subtract this value from classification logits before
activation. Used for affine score calibration (Platt scaling):
activation((logit - logit_mean) / logit_sigma). Defaults to None.
"""
logit_sigma: float | None = None
"""
If provided, divide the classification logits by this value after
mean subtraction. Used for affine score calibration (Platt scaling):
activation((logit - logit_mean) / logit_sigma). Defaults to None.
"""
# Deprecated aliases — will be removed in v0.21
logit_bias: float | None = None
"""
If provided, apply classification logit biases. Defaults to None.
Deprecated: Use logit_mean instead. Will be removed in v0.21.
"""
logit_scale: float | None = None
"""
If provided, scale the classification logits by this factor before
activation. Combined with logit_bias, enables affine score calibration:
activation(logit_scale * (score - logit_bias)). Defaults to None.
Deprecated: Use logit_sigma instead (note: logit_sigma = 1/logit_scale).
Will be removed in v0.21.
"""
## for reward models
......@@ -105,6 +119,39 @@ class PoolerConfig:
"""
def __post_init__(self) -> None:
# Handle deprecated logit_bias → logit_mean
if self.logit_bias is not None:
if self.logit_mean is not None:
raise ValueError(
"Cannot set both `logit_bias` and `logit_mean`. "
"`logit_bias` is deprecated, use `logit_mean` instead."
)
logger.warning(
"`logit_bias` is deprecated and will be removed in v0.21. "
"Use `logit_mean` instead."
)
self.logit_mean = self.logit_bias
self.logit_bias = None
# Handle deprecated logit_scale → logit_sigma
if self.logit_scale is not None:
if self.logit_sigma is not None:
raise ValueError(
"Cannot set both `logit_scale` and `logit_sigma`. "
"`logit_scale` is deprecated, use `logit_sigma` instead."
)
logger.warning(
"`logit_scale` is deprecated and will be removed in v0.21. "
"Use `logit_sigma` instead (logit_sigma = 1/logit_scale)."
)
if self.logit_scale == 0:
raise ValueError("logit_scale cannot be 0 (division by zero)")
self.logit_sigma = 1.0 / self.logit_scale
self.logit_scale = None
if self.logit_sigma is not None and self.logit_sigma == 0:
raise ValueError("logit_sigma cannot be 0 (division by zero)")
if pooling_type := self.pooling_type:
if self.seq_pooling_type is not None:
raise ValueError(
......
......@@ -103,16 +103,16 @@ class ClassifierPoolerHead(SequencePoolerHead):
def __init__(
self,
classifier: ClassifierFn | None = None,
logit_bias: float | None = None,
logit_scale: float | None = None,
logit_mean: float | None = None,
logit_sigma: float | None = None,
head_dtype: torch.dtype | str | None = None,
activation: ActivationFn | None = None,
) -> None:
super().__init__()
self.classifier = classifier
self.logit_bias = logit_bias
self.logit_scale = logit_scale
self.logit_mean = logit_mean
self.logit_sigma = logit_sigma
self.head_dtype = head_dtype
self.activation = activation
......@@ -140,10 +140,11 @@ class ClassifierPoolerHead(SequencePoolerHead):
logits = pooled_data
# logits shape: [batchsize, num_labels]
if self.logit_bias is not None:
logits -= self.logit_bias
if self.logit_scale is not None:
logits *= self.logit_scale
# Affine score calibration: activation((logit - mean) / sigma)
if self.logit_mean is not None:
logits = logits - self.logit_mean
if self.logit_sigma is not None:
logits = logits / self.logit_sigma
if self.activation is not None:
flags = [p.use_activation for p in pooling_params]
......
......@@ -118,8 +118,8 @@ def pooler_for_classify(
head = ClassifierPoolerHead(
head_dtype=model_config.head_dtype,
classifier=classifier,
logit_bias=model_config.pooler_config.logit_bias,
logit_scale=model_config.pooler_config.logit_scale,
logit_mean=model_config.pooler_config.logit_mean,
logit_sigma=model_config.pooler_config.logit_sigma,
activation=resolve_classifier_act_fn(
model_config, static_num_labels=True, act_fn=act_fn
),
......
......@@ -92,16 +92,16 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
def __init__(
self,
classifier: ClassifierFn | None = None,
logit_bias: float | None = None,
logit_scale: float | None = None,
logit_mean: float | None = None,
logit_sigma: float | None = None,
head_dtype: torch.dtype | str | None = None,
activation: ActivationFn | None = None,
) -> None:
super().__init__()
self.classifier = classifier
self.logit_bias = logit_bias
self.logit_scale = logit_scale
self.logit_mean = logit_mean
self.logit_sigma = logit_sigma
self.head_dtype = head_dtype
self.activation = activation
......@@ -127,10 +127,11 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
logits = pooled_data
# logits shape: [n_token, num_labels]
if self.logit_bias is not None:
logits -= self.logit_bias
if self.logit_scale is not None:
logits *= self.logit_scale
# Affine score calibration: activation((logit - mean) / sigma)
if self.logit_mean is not None:
logits = logits - self.logit_mean
if self.logit_sigma is not None:
logits = logits / self.logit_sigma
if self.activation is not None and pooling_param.use_activation:
logits = self.activation(logits)
......
......@@ -127,8 +127,8 @@ def pooler_for_token_classify(
head = TokenClassifierPoolerHead(
head_dtype=model_config.head_dtype,
classifier=classifier,
logit_bias=model_config.pooler_config.logit_bias,
logit_scale=model_config.pooler_config.logit_scale,
logit_mean=model_config.pooler_config.logit_mean,
logit_sigma=model_config.pooler_config.logit_sigma,
activation=resolve_classifier_act_fn(
model_config, static_num_labels=False, act_fn=act_fn
),
......
......@@ -232,8 +232,8 @@ class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
config = model_config.hf_config
config.num_labels = 1
pooler_config = model_config.pooler_config
if pooler_config.logit_bias is None:
pooler_config.logit_bias = 2.65
if pooler_config.logit_mean is None:
pooler_config.logit_mean = 2.65
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment