Unverified Commit fa6ae311 authored by Jesus Federico's avatar Jesus Federico Committed by GitHub
Browse files

feat: rename logit_bias/logit_scale to logit_mean/logit_sigma for affine score calibration (#39530)


Signed-off-by: default avatarJesus Federico <jefp@amazon.com>
Signed-off-by: default avatarwang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: default avatarwang.yuqi <yuqi.wang@daocloud.io>
parent 2a3c32ce
...@@ -273,27 +273,25 @@ Affine Score Calibration, also known as [Platt Scaling](https://en.wikipedia.org ...@@ -273,27 +273,25 @@ Affine Score Calibration, also known as [Platt Scaling](https://en.wikipedia.org
The calibration follows the transformation: The calibration follows the transformation:
`activation(logit_scale * (logit - logit_bias))` `activation((logit - logit_mean) / logit_sigma)`
| Parameter | Default | Description | | Parameter | Default | Description |
| --------- | ------- | ----------- | | --------- | ------- | ----------- |
| `logit_bias` | `None` | Bias subtracted from logits before activation | | `logit_mean` | `None` | Mean subtracted from logits (centers scores) |
| `logit_scale` | `None` | Scale factor applied to logits after bias subtraction | | `logit_sigma` | `None` | Standard deviation used to scale logits after mean subtraction |
Note: `logit_bias` is **subtracted** from the logits (not added), consistent with the `sigmoid_normalize` convention where `sigmoid(x - bias)` centers the sigmoid around the bias value.
The computation order is as follows: The computation order is as follows:
```python ```python
logits -= logit_bias # subtract bias (center scores) logits -= logit_mean # subtract mean (center scores)
logits *= logit_scale # scale logits logits /= logit_sigma # divide by sigma (scale)
logits = activation(logits) # e.g. sigmoid logits = activation(logits) # e.g. sigmoid
``` ```
Example configuration: Example configuration:
```bash ```bash
--pooler-config '{"use_activation": true, "logit_bias": 4.5, "logit_scale": 1.0}' --pooler-config '{"use_activation": true, "logit_mean": 4.5, "logit_sigma": 1.0}'
``` ```
## Removed Features ## Removed Features
...@@ -301,3 +299,7 @@ Example configuration: ...@@ -301,3 +299,7 @@ Example configuration:
### Remove softmax from PoolingParams ### Remove softmax from PoolingParams
We have already removed `softmax` and `activation` from PoolingParams. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function. We have already removed `softmax` and `activation` from PoolingParams. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.
### Remove `logit_bias` and `logit_scale`
`logit_bias` and `logit_scale` are deprecated aliases for `logit_mean` and `logit_sigma` respectively. When using `logit_scale`, it is automatically converted to `logit_sigma = 1/logit_scale`. These deprecated parameters will be removed in v0.21.
...@@ -77,17 +77,31 @@ class PoolerConfig: ...@@ -77,17 +77,31 @@ class PoolerConfig:
Defaults to None (i.e. set to max_model_len). Defaults to None (i.e. set to max_model_len).
""" """
## for classification models ## for classification models — affine score calibration
logit_mean: float | None = None
"""
If provided, subtract this value from classification logits before
activation. Used for affine score calibration (Platt scaling):
activation((logit - logit_mean) / logit_sigma). Defaults to None.
"""
logit_sigma: float | None = None
"""
If provided, divide the classification logits by this value after
mean subtraction. Used for affine score calibration (Platt scaling):
activation((logit - logit_mean) / logit_sigma). Defaults to None.
"""
# Deprecated aliases — will be removed in v0.21
logit_bias: float | None = None logit_bias: float | None = None
""" """
If provided, apply classification logit biases. Defaults to None. Deprecated: Use logit_mean instead. Will be removed in v0.21.
""" """
logit_scale: float | None = None logit_scale: float | None = None
""" """
If provided, scale the classification logits by this factor before Deprecated: Use logit_sigma instead (note: logit_sigma = 1/logit_scale).
activation. Combined with logit_bias, enables affine score calibration: Will be removed in v0.21.
activation(logit_scale * (score - logit_bias)). Defaults to None.
""" """
## for reward models ## for reward models
...@@ -105,6 +119,39 @@ class PoolerConfig: ...@@ -105,6 +119,39 @@ class PoolerConfig:
""" """
def __post_init__(self) -> None: def __post_init__(self) -> None:
# Handle deprecated logit_bias → logit_mean
if self.logit_bias is not None:
if self.logit_mean is not None:
raise ValueError(
"Cannot set both `logit_bias` and `logit_mean`. "
"`logit_bias` is deprecated, use `logit_mean` instead."
)
logger.warning(
"`logit_bias` is deprecated and will be removed in v0.21. "
"Use `logit_mean` instead."
)
self.logit_mean = self.logit_bias
self.logit_bias = None
# Handle deprecated logit_scale → logit_sigma
if self.logit_scale is not None:
if self.logit_sigma is not None:
raise ValueError(
"Cannot set both `logit_scale` and `logit_sigma`. "
"`logit_scale` is deprecated, use `logit_sigma` instead."
)
logger.warning(
"`logit_scale` is deprecated and will be removed in v0.21. "
"Use `logit_sigma` instead (logit_sigma = 1/logit_scale)."
)
if self.logit_scale == 0:
raise ValueError("logit_scale cannot be 0 (division by zero)")
self.logit_sigma = 1.0 / self.logit_scale
self.logit_scale = None
if self.logit_sigma is not None and self.logit_sigma == 0:
raise ValueError("logit_sigma cannot be 0 (division by zero)")
if pooling_type := self.pooling_type: if pooling_type := self.pooling_type:
if self.seq_pooling_type is not None: if self.seq_pooling_type is not None:
raise ValueError( raise ValueError(
......
...@@ -103,16 +103,16 @@ class ClassifierPoolerHead(SequencePoolerHead): ...@@ -103,16 +103,16 @@ class ClassifierPoolerHead(SequencePoolerHead):
def __init__( def __init__(
self, self,
classifier: ClassifierFn | None = None, classifier: ClassifierFn | None = None,
logit_bias: float | None = None, logit_mean: float | None = None,
logit_scale: float | None = None, logit_sigma: float | None = None,
head_dtype: torch.dtype | str | None = None, head_dtype: torch.dtype | str | None = None,
activation: ActivationFn | None = None, activation: ActivationFn | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.classifier = classifier self.classifier = classifier
self.logit_bias = logit_bias self.logit_mean = logit_mean
self.logit_scale = logit_scale self.logit_sigma = logit_sigma
self.head_dtype = head_dtype self.head_dtype = head_dtype
self.activation = activation self.activation = activation
...@@ -140,10 +140,11 @@ class ClassifierPoolerHead(SequencePoolerHead): ...@@ -140,10 +140,11 @@ class ClassifierPoolerHead(SequencePoolerHead):
logits = pooled_data logits = pooled_data
# logits shape: [batchsize, num_labels] # logits shape: [batchsize, num_labels]
if self.logit_bias is not None: # Affine score calibration: activation((logit - mean) / sigma)
logits -= self.logit_bias if self.logit_mean is not None:
if self.logit_scale is not None: logits = logits - self.logit_mean
logits *= self.logit_scale if self.logit_sigma is not None:
logits = logits / self.logit_sigma
if self.activation is not None: if self.activation is not None:
flags = [p.use_activation for p in pooling_params] flags = [p.use_activation for p in pooling_params]
......
...@@ -118,8 +118,8 @@ def pooler_for_classify( ...@@ -118,8 +118,8 @@ def pooler_for_classify(
head = ClassifierPoolerHead( head = ClassifierPoolerHead(
head_dtype=model_config.head_dtype, head_dtype=model_config.head_dtype,
classifier=classifier, classifier=classifier,
logit_bias=model_config.pooler_config.logit_bias, logit_mean=model_config.pooler_config.logit_mean,
logit_scale=model_config.pooler_config.logit_scale, logit_sigma=model_config.pooler_config.logit_sigma,
activation=resolve_classifier_act_fn( activation=resolve_classifier_act_fn(
model_config, static_num_labels=True, act_fn=act_fn model_config, static_num_labels=True, act_fn=act_fn
), ),
......
...@@ -92,16 +92,16 @@ class TokenClassifierPoolerHead(TokenPoolerHead): ...@@ -92,16 +92,16 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
def __init__( def __init__(
self, self,
classifier: ClassifierFn | None = None, classifier: ClassifierFn | None = None,
logit_bias: float | None = None, logit_mean: float | None = None,
logit_scale: float | None = None, logit_sigma: float | None = None,
head_dtype: torch.dtype | str | None = None, head_dtype: torch.dtype | str | None = None,
activation: ActivationFn | None = None, activation: ActivationFn | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.classifier = classifier self.classifier = classifier
self.logit_bias = logit_bias self.logit_mean = logit_mean
self.logit_scale = logit_scale self.logit_sigma = logit_sigma
self.head_dtype = head_dtype self.head_dtype = head_dtype
self.activation = activation self.activation = activation
...@@ -127,10 +127,11 @@ class TokenClassifierPoolerHead(TokenPoolerHead): ...@@ -127,10 +127,11 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
logits = pooled_data logits = pooled_data
# logits shape: [n_token, num_labels] # logits shape: [n_token, num_labels]
if self.logit_bias is not None: # Affine score calibration: activation((logit - mean) / sigma)
logits -= self.logit_bias if self.logit_mean is not None:
if self.logit_scale is not None: logits = logits - self.logit_mean
logits *= self.logit_scale if self.logit_sigma is not None:
logits = logits / self.logit_sigma
if self.activation is not None and pooling_param.use_activation: if self.activation is not None and pooling_param.use_activation:
logits = self.activation(logits) logits = self.activation(logits)
......
...@@ -127,8 +127,8 @@ def pooler_for_token_classify( ...@@ -127,8 +127,8 @@ def pooler_for_token_classify(
head = TokenClassifierPoolerHead( head = TokenClassifierPoolerHead(
head_dtype=model_config.head_dtype, head_dtype=model_config.head_dtype,
classifier=classifier, classifier=classifier,
logit_bias=model_config.pooler_config.logit_bias, logit_mean=model_config.pooler_config.logit_mean,
logit_scale=model_config.pooler_config.logit_scale, logit_sigma=model_config.pooler_config.logit_sigma,
activation=resolve_classifier_act_fn( activation=resolve_classifier_act_fn(
model_config, static_num_labels=False, act_fn=act_fn model_config, static_num_labels=False, act_fn=act_fn
), ),
......
...@@ -232,8 +232,8 @@ class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig): ...@@ -232,8 +232,8 @@ class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
config = model_config.hf_config config = model_config.hf_config
config.num_labels = 1 config.num_labels = 1
pooler_config = model_config.pooler_config pooler_config = model_config.pooler_config
if pooler_config.logit_bias is None: if pooler_config.logit_mean is None:
pooler_config.logit_bias = 2.65 pooler_config.logit_mean = 2.65
class LlamaBidirectionalConfig(VerifyAndUpdateConfig): class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment