Unverified commit 232214b2, authored by Cyrus Leung, committed by GitHub
Browse files

[Bugfix] Replace `PoolingParams.normalize` with `use_activation` (#32243)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent eb28e806
...@@ -47,7 +47,7 @@ The key parameters for chunked processing are in the `--pooler-config`: ...@@ -47,7 +47,7 @@ The key parameters for chunked processing are in the `--pooler-config`:
```json ```json
{ {
"pooling_type": "auto", "pooling_type": "auto",
"normalize": true, "use_activation": true,
"enable_chunked_processing": true, "enable_chunked_processing": true,
"max_embed_len": 3072000 "max_embed_len": 3072000
} }
......
...@@ -14,7 +14,7 @@ Prerequisites: ...@@ -14,7 +14,7 @@ Prerequisites:
# MEAN pooling (processes all chunks, recommended for complete coverage) # MEAN pooling (processes all chunks, recommended for complete coverage)
vllm serve intfloat/multilingual-e5-large \ vllm serve intfloat/multilingual-e5-large \
--pooler-config \ --pooler-config \
'{"pooling_type": "MEAN", "normalize": true, ' \ '{"pooling_type": "MEAN", "use_activation": true, ' \
'"enable_chunked_processing": true, "max_embed_len": 3072000}' \ '"enable_chunked_processing": true, "max_embed_len": 3072000}' \
--served-model-name multilingual-e5-large \ --served-model-name multilingual-e5-large \
--trust-remote-code \ --trust-remote-code \
...@@ -24,7 +24,7 @@ Prerequisites: ...@@ -24,7 +24,7 @@ Prerequisites:
# OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks) # OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks)
vllm serve BAAI/bge-large-en-v1.5 \ vllm serve BAAI/bge-large-en-v1.5 \
--pooler-config \ --pooler-config \
'{"pooling_type": "CLS", "normalize": true, ' \ '{"pooling_type": "CLS", "use_activation": true, ' \
'"enable_chunked_processing": true, "max_embed_len": 1048576}' \ '"enable_chunked_processing": true, "max_embed_len": 1048576}' \
--served-model-name bge-large-en-v1.5 \ --served-model-name bge-large-en-v1.5 \
--trust-remote-code \ --trust-remote-code \
......
...@@ -96,7 +96,7 @@ echo "" ...@@ -96,7 +96,7 @@ echo ""
echo "🔧 Starting server with enhanced chunked processing configuration..." echo "🔧 Starting server with enhanced chunked processing configuration..."
# Build pooler config JSON # Build pooler config JSON
POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"normalize\": true, \"enable_chunked_processing\": ${VLLM_ENABLE_CHUNKED_PROCESSING}, \"max_embed_len\": ${MAX_EMBED_LEN}}" POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"use_activation\": true, \"enable_chunked_processing\": ${VLLM_ENABLE_CHUNKED_PROCESSING}, \"max_embed_len\": ${MAX_EMBED_LEN}}"
# Start vLLM server with enhanced chunked processing # Start vLLM server with enhanced chunked processing
vllm serve "$MODEL_NAME" \ vllm serve "$MODEL_NAME" \
......
...@@ -53,7 +53,9 @@ def test_token_embed(llm: LLM): ...@@ -53,7 +53,9 @@ def test_token_embed(llm: LLM):
def test_pooling_params(llm: LLM): def test_pooling_params(llm: LLM):
def get_outputs(normalize): def get_outputs(normalize):
outputs = llm.embed( outputs = llm.embed(
prompts, pooling_params=PoolingParams(normalize=normalize), use_tqdm=False prompts,
pooling_params=PoolingParams(use_activation=normalize),
use_tqdm=False,
) )
return torch.tensor([x.outputs.embedding for x in outputs]) return torch.tensor([x.outputs.embedding for x in outputs])
......
...@@ -216,7 +216,7 @@ def server_with_chunked_processing(): ...@@ -216,7 +216,7 @@ def server_with_chunked_processing():
"512", # Set smaller max_model_len to trigger chunking mechanism "512", # Set smaller max_model_len to trigger chunking mechanism
"--pooler-config", "--pooler-config",
( (
'{"pooling_type": "MEAN", "normalize": true, ' '{"pooling_type": "MEAN", "use_activation": true, '
'"enable_chunked_processing": true, "max_embed_len": 10000}' '"enable_chunked_processing": true, "max_embed_len": 10000}'
), ),
"--gpu-memory-utilization", "--gpu-memory-utilization",
......
...@@ -236,17 +236,14 @@ class TestModel: ...@@ -236,17 +236,14 @@ class TestModel:
"use_activation": use_activation, "use_activation": use_activation,
}, },
) )
if response.status_code != 200:
return response
outputs = response.json() outputs = response.json()
return torch.tensor([x["score"] for x in outputs["data"]]) return torch.tensor([x["score"] for x in outputs["data"]])
if model["is_cross_encoder"]:
default = get_outputs(use_activation=None) default = get_outputs(use_activation=None)
w_activation = get_outputs(use_activation=True) w_activation = get_outputs(use_activation=True)
wo_activation = get_outputs(use_activation=False) wo_activation = get_outputs(use_activation=False)
if model["is_cross_encoder"]:
assert torch.allclose(default, w_activation, atol=1e-2), ( assert torch.allclose(default, w_activation, atol=1e-2), (
"Default should use activation." "Default should use activation."
) )
...@@ -256,9 +253,3 @@ class TestModel: ...@@ -256,9 +253,3 @@ class TestModel:
assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), ( assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), (
"w_activation should be close to activation(wo_activation)." "w_activation should be close to activation(wo_activation)."
) )
else:
get_outputs(use_activation=None)
# The activation parameter only works for the is_cross_encoder model
response = get_outputs(use_activation=True)
assert response.status_code == 400
...@@ -48,7 +48,7 @@ def test_model_loading_with_params(vllm_runner, monkeypatch): ...@@ -48,7 +48,7 @@ def test_model_loading_with_params(vllm_runner, monkeypatch):
# asserts on the pooling config files # asserts on the pooling config files
assert model_config.pooler_config.seq_pooling_type == "CLS" assert model_config.pooler_config.seq_pooling_type == "CLS"
assert model_config.pooler_config.tok_pooling_type == "ALL" assert model_config.pooler_config.tok_pooling_type == "ALL"
assert model_config.pooler_config.normalize assert model_config.pooler_config.use_activation
# asserts on the tokenizer loaded # asserts on the tokenizer loaded
assert model_config.tokenizer == "BAAI/bge-base-en-v1.5" assert model_config.tokenizer == "BAAI/bge-base-en-v1.5"
...@@ -93,7 +93,7 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch): ...@@ -93,7 +93,7 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch):
# asserts on the pooling config files # asserts on the pooling config files
assert model_config.pooler_config.seq_pooling_type == "MEAN" assert model_config.pooler_config.seq_pooling_type == "MEAN"
assert model_config.pooler_config.tok_pooling_type == "ALL" assert model_config.pooler_config.tok_pooling_type == "ALL"
assert model_config.pooler_config.normalize assert model_config.pooler_config.use_activation
# asserts on the tokenizer loaded # asserts on the tokenizer loaded
assert model_config.tokenizer == "intfloat/multilingual-e5-base" assert model_config.tokenizer == "intfloat/multilingual-e5-base"
......
...@@ -66,7 +66,7 @@ def test_embed_models_using_normalize( ...@@ -66,7 +66,7 @@ def test_embed_models_using_normalize(
model, model,
max_model_len=512, max_model_len=512,
dtype=dtype, dtype=dtype,
pooler_config=PoolerConfig(normalize=False), pooler_config=PoolerConfig(use_activation=False),
) as vllm_model: ) as vllm_model:
wo_normalize = torch.tensor(vllm_model.embed(example_prompts)) wo_normalize = torch.tensor(vllm_model.embed(example_prompts))
...@@ -74,7 +74,7 @@ def test_embed_models_using_normalize( ...@@ -74,7 +74,7 @@ def test_embed_models_using_normalize(
model, model,
max_model_len=512, max_model_len=512,
dtype=dtype, dtype=dtype,
pooler_config=PoolerConfig(normalize=True), pooler_config=PoolerConfig(use_activation=True),
) as vllm_model: ) as vllm_model:
w_normalize = torch.tensor(vllm_model.embed(example_prompts)) w_normalize = torch.tensor(vllm_model.embed(example_prompts))
...@@ -146,7 +146,7 @@ def test_multi_vector_retrieval_models_using_normalize( ...@@ -146,7 +146,7 @@ def test_multi_vector_retrieval_models_using_normalize(
model, model,
max_model_len=512, max_model_len=512,
dtype=dtype, dtype=dtype,
pooler_config=PoolerConfig(normalize=False), pooler_config=PoolerConfig(use_activation=False),
) as vllm_model: ) as vllm_model:
wo_normalize = vllm_model.token_embed(example_prompts) wo_normalize = vllm_model.token_embed(example_prompts)
...@@ -154,7 +154,7 @@ def test_multi_vector_retrieval_models_using_normalize( ...@@ -154,7 +154,7 @@ def test_multi_vector_retrieval_models_using_normalize(
model, model,
max_model_len=512, max_model_len=512,
dtype=dtype, dtype=dtype,
pooler_config=PoolerConfig(normalize=True), pooler_config=PoolerConfig(use_activation=True),
) as vllm_model: ) as vllm_model:
w_normalize = vllm_model.token_embed(example_prompts) w_normalize = vllm_model.token_embed(example_prompts)
......
...@@ -162,7 +162,7 @@ def test_get_pooling_config(): ...@@ -162,7 +162,7 @@ def test_get_pooling_config():
model_config = ModelConfig(model_id) model_config = ModelConfig(model_id)
assert model_config.pooler_config is not None assert model_config.pooler_config is not None
assert model_config.pooler_config.normalize assert model_config.pooler_config.use_activation
assert model_config.pooler_config.seq_pooling_type == "MEAN" assert model_config.pooler_config.seq_pooling_type == "MEAN"
assert model_config.pooler_config.tok_pooling_type == "ALL" assert model_config.pooler_config.tok_pooling_type == "ALL"
......
...@@ -18,7 +18,7 @@ EMBEDDING_MODELS = [ ...@@ -18,7 +18,7 @@ EMBEDDING_MODELS = [
] ]
classify_parameters = ["use_activation"] classify_parameters = ["use_activation"]
embed_parameters = ["dimensions", "normalize"] embed_parameters = ["dimensions", "use_activation"]
step_pooling_parameters = ["step_tag_id", "returned_token_ids"] step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
...@@ -42,17 +42,17 @@ def test_embed(): ...@@ -42,17 +42,17 @@ def test_embed():
task = "embed" task = "embed"
model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS")) model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
pooling_params = PoolingParams(normalize=None) pooling_params = PoolingParams(use_activation=None)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=True) pooling_params = PoolingParams(use_activation=True)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=False) pooling_params = PoolingParams(use_activation=False)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
invalid_parameters = classify_parameters + step_pooling_parameters invalid_parameters = classify_parameters + step_pooling_parameters
for p in invalid_parameters: for p in set(invalid_parameters) - set(embed_parameters):
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
...@@ -98,7 +98,7 @@ def test_classify(task): ...@@ -98,7 +98,7 @@ def test_classify(task):
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
invalid_parameters = embed_parameters + step_pooling_parameters invalid_parameters = embed_parameters + step_pooling_parameters
for p in invalid_parameters: for p in set(invalid_parameters) - set(classify_parameters):
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
...@@ -111,20 +111,20 @@ def test_token_embed(pooling_type: str): ...@@ -111,20 +111,20 @@ def test_token_embed(pooling_type: str):
pooler_config=PoolerConfig(tok_pooling_type=pooling_type) pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
) )
pooling_params = PoolingParams(normalize=None) pooling_params = PoolingParams(use_activation=None)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=True) pooling_params = PoolingParams(use_activation=True)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=False) pooling_params = PoolingParams(use_activation=False)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
invalid_parameters = classify_parameters invalid_parameters = classify_parameters
if pooling_type != "STEP": if pooling_type != "STEP":
invalid_parameters = classify_parameters + step_pooling_parameters invalid_parameters = classify_parameters + step_pooling_parameters
for p in invalid_parameters: for p in set(invalid_parameters) - set(embed_parameters):
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
...@@ -150,7 +150,7 @@ def test_token_classify(pooling_type: str): ...@@ -150,7 +150,7 @@ def test_token_classify(pooling_type: str):
if pooling_type != "STEP": if pooling_type != "STEP":
invalid_parameters = embed_parameters + step_pooling_parameters invalid_parameters = embed_parameters + step_pooling_parameters
for p in invalid_parameters: for p in set(invalid_parameters) - set(classify_parameters):
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
...@@ -48,7 +48,7 @@ class PoolerConfig: ...@@ -48,7 +48,7 @@ class PoolerConfig:
## for embeddings models ## for embeddings models
normalize: bool | None = None normalize: bool | None = None
""" """
Whether to normalize the embeddings outputs. Defaults to True. DEPRECATED: please use `use_activation` instead.
""" """
dimensions: int | None = None dimensions: int | None = None
""" """
...@@ -75,11 +75,11 @@ class PoolerConfig: ...@@ -75,11 +75,11 @@ class PoolerConfig:
## for classification models ## for classification models
softmax: float | None = None softmax: float | None = None
""" """
softmax will be deprecated, please use use_activation instead. DEPRECATED: please use `use_activation` instead.
""" """
activation: float | None = None activation: float | None = None
""" """
activation will be deprecated, please use use_activation instead. DEPRECATED: please use `use_activation` instead.
""" """
use_activation: bool | None = None use_activation: bool | None = None
""" """
...@@ -164,17 +164,24 @@ class PoolerConfig: ...@@ -164,17 +164,24 @@ class PoolerConfig:
def get_use_activation(o: object): def get_use_activation(o: object):
if softmax := getattr(o, "softmax", None) is not None: if (normalize := getattr(o, "normalize", None)) is not None:
logger.warning_once( logger.warning_once(
"softmax will be deprecated and will be removed in v0.15. " "`normalize` is deprecated and will be removed in v0.15. "
"Please use use_activation instead." "Please use `use_activation` instead."
)
return normalize
if (softmax := getattr(o, "softmax", None)) is not None:
logger.warning_once(
"`softmax` is deprecated and will be removed in v0.15. "
"Please use `use_activation` instead."
) )
return softmax return softmax
if activation := getattr(o, "activation", None) is not None: if (activation := getattr(o, "activation", None)) is not None:
logger.warning_once( logger.warning_once(
"activation will be deprecated and will be removed in v0.15. " "`activation` is deprecated and will be removed in v0.15. "
"Please use use_activation instead." "Please use `use_activation` instead."
) )
return activation return activation
......
...@@ -75,7 +75,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): ...@@ -75,7 +75,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
return PoolingParams( return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
dimensions=self.dimensions, dimensions=self.dimensions,
normalize=self.normalize, use_activation=self.normalize,
) )
...@@ -189,7 +189,7 @@ class EmbeddingChatRequest(OpenAIBaseModel): ...@@ -189,7 +189,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
return PoolingParams( return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
dimensions=self.dimensions, dimensions=self.dimensions,
normalize=self.normalize, use_activation=self.normalize,
) )
......
...@@ -40,7 +40,6 @@ class PoolingCompletionRequest(EmbeddingCompletionRequest): ...@@ -40,7 +40,6 @@ class PoolingCompletionRequest(EmbeddingCompletionRequest):
return PoolingParams( return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
dimensions=self.dimensions, dimensions=self.dimensions,
normalize=self.normalize,
use_activation=get_use_activation(self), use_activation=get_use_activation(self),
) )
...@@ -66,7 +65,6 @@ class PoolingChatRequest(EmbeddingChatRequest): ...@@ -66,7 +65,6 @@ class PoolingChatRequest(EmbeddingChatRequest):
return PoolingParams( return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
dimensions=self.dimensions, dimensions=self.dimensions,
normalize=self.normalize,
use_activation=get_use_activation(self), use_activation=get_use_activation(self),
) )
......
...@@ -83,7 +83,7 @@ class EmbeddingPoolerHead(SequencePoolerHead): ...@@ -83,7 +83,7 @@ class EmbeddingPoolerHead(SequencePoolerHead):
# for normalize # for normalize
if self.activation is not None: if self.activation is not None:
flags = [p.normalize for p in pooling_params] flags = [p.use_activation for p in pooling_params]
if len(set(flags)) == 1: if len(set(flags)) == 1:
if flags[0]: if flags[0]:
pooled_data = self.activation(pooled_data) pooled_data = self.activation(pooled_data)
......
...@@ -95,8 +95,8 @@ def pooler_for_embed(pooler_config: PoolerConfig): ...@@ -95,8 +95,8 @@ def pooler_for_embed(pooler_config: PoolerConfig):
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config model_config = vllm_config.model_config
head = EmbeddingPoolerHead( head = EmbeddingPoolerHead(
projector=_load_st_projector(model_config),
head_dtype=model_config.head_dtype, head_dtype=model_config.head_dtype,
projector=_load_st_projector(model_config),
activation=PoolerNormalize(), activation=PoolerNormalize(),
) )
...@@ -116,9 +116,9 @@ def pooler_for_classify( ...@@ -116,9 +116,9 @@ def pooler_for_classify(
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config model_config = vllm_config.model_config
head = ClassifierPoolerHead( head = ClassifierPoolerHead(
head_dtype=model_config.head_dtype,
classifier=classifier, classifier=classifier,
logit_bias=model_config.pooler_config.logit_bias, logit_bias=model_config.pooler_config.logit_bias,
head_dtype=model_config.head_dtype,
activation=resolve_classifier_act_fn( activation=resolve_classifier_act_fn(
model_config, static_num_labels=True, act_fn=act_fn model_config, static_num_labels=True, act_fn=act_fn
), ),
......
...@@ -44,14 +44,14 @@ class TokenPoolerHead(nn.Module, ABC): ...@@ -44,14 +44,14 @@ class TokenPoolerHead(nn.Module, ABC):
class TokenEmbeddingPoolerHead(TokenPoolerHead): class TokenEmbeddingPoolerHead(TokenPoolerHead):
def __init__( def __init__(
self, self,
projector: ProjectorFn | None = None,
head_dtype: torch.dtype | str | None = None, head_dtype: torch.dtype | str | None = None,
projector: ProjectorFn | None = None,
activation: ActivationFn | None = None, activation: ActivationFn | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.projector = projector
self.head_dtype = head_dtype self.head_dtype = head_dtype
self.projector = projector
self.activation = activation self.activation = activation
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
...@@ -79,7 +79,7 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead): ...@@ -79,7 +79,7 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
pooled_data = pooled_data[..., : pooling_param.dimensions] pooled_data = pooled_data[..., : pooling_param.dimensions]
# for normalize # for normalize
if self.activation is not None and pooling_param.normalize: if self.activation is not None and pooling_param.use_activation:
pooled_data = self.activation(pooled_data) pooled_data = self.activation(pooled_data)
# pooled_data shape: [n_tokens, embedding_dimension] # pooled_data shape: [n_tokens, embedding_dimension]
......
...@@ -95,8 +95,8 @@ def pooler_for_token_embed(pooler_config: PoolerConfig): ...@@ -95,8 +95,8 @@ def pooler_for_token_embed(pooler_config: PoolerConfig):
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config model_config = vllm_config.model_config
head = TokenEmbeddingPoolerHead( head = TokenEmbeddingPoolerHead(
projector=_load_st_projector(model_config),
head_dtype=model_config.head_dtype, head_dtype=model_config.head_dtype,
projector=_load_st_projector(model_config),
activation=PoolerNormalize(), activation=PoolerNormalize(),
) )
...@@ -116,9 +116,9 @@ def pooler_for_token_classify( ...@@ -116,9 +116,9 @@ def pooler_for_token_classify(
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config model_config = vllm_config.model_config
head = TokenClassifierPoolerHead( head = TokenClassifierPoolerHead(
head_dtype=model_config.head_dtype,
classifier=classifier, classifier=classifier,
logit_bias=model_config.pooler_config.logit_bias, logit_bias=model_config.pooler_config.logit_bias,
head_dtype=model_config.head_dtype,
activation=resolve_classifier_act_fn( activation=resolve_classifier_act_fn(
model_config, static_num_labels=False, act_fn=act_fn model_config, static_num_labels=False, act_fn=act_fn
), ),
......
...@@ -116,8 +116,8 @@ class BertPooler(SequencePooler): ...@@ -116,8 +116,8 @@ class BertPooler(SequencePooler):
# Use lambdas so that weights are not registered under `self.head` # Use lambdas so that weights are not registered under `self.head`
self.head = EmbeddingPoolerHead( self.head = EmbeddingPoolerHead(
projector=lambda x: self.dense(x),
head_dtype=head_dtype, head_dtype=head_dtype,
projector=lambda x: self.dense(x),
activation=LambdaPoolerActivation(self.act_fn), activation=LambdaPoolerActivation(self.act_fn),
) )
......
...@@ -309,12 +309,13 @@ class ModernBertPooler(SequencePooler): ...@@ -309,12 +309,13 @@ class ModernBertPooler(SequencePooler):
config.hidden_size, config.hidden_size,
eps=config.norm_eps, eps=config.norm_eps,
bias=config.norm_bias, bias=config.norm_bias,
dtype=head_dtype,
) )
# Use lambdas so that weights are not registered under `self.head` # Use lambdas so that weights are not registered under `self.head`
self.head = EmbeddingPoolerHead( self.head = EmbeddingPoolerHead(
projector=lambda x: self.dense(x),
head_dtype=head_dtype, head_dtype=head_dtype,
projector=lambda x: self.dense(x),
activation=LambdaPoolerActivation(lambda x: self.norm(self.act(x))), activation=LambdaPoolerActivation(lambda x: self.norm(self.act(x))),
) )
......
...@@ -26,9 +26,9 @@ class PoolingParams( ...@@ -26,9 +26,9 @@ class PoolingParams(
Set to None to disable truncation. Set to None to disable truncation.
dimensions: Reduce the dimensions of embeddings dimensions: Reduce the dimensions of embeddings
if model support matryoshka representation. if model support matryoshka representation.
normalize: Whether to normalize the embeddings outputs. normalize: Deprecated, please use use_activation instead.
softmax: softmax will be deprecated, please use use_activation instead. softmax: Deprecated, please use use_activation instead.
activation: activation will be deprecated, please use use_activation instead. activation: Deprecated, please use use_activation instead.
use_activation: Whether to apply activation function to use_activation: Whether to apply activation function to
the classification outputs. the classification outputs.
""" """
...@@ -63,15 +63,15 @@ class PoolingParams( ...@@ -63,15 +63,15 @@ class PoolingParams(
@property @property
def all_parameters(self) -> list[str]: def all_parameters(self) -> list[str]:
return ["dimensions", "normalize", "use_activation"] return ["dimensions", "use_activation"]
@property @property
def valid_parameters(self): def valid_parameters(self):
return { return {
"embed": ["dimensions", "normalize"], "embed": ["dimensions", "use_activation"],
"classify": ["use_activation"], "classify": ["use_activation"],
"score": ["use_activation"], "score": ["use_activation"],
"token_embed": ["dimensions", "normalize"], "token_embed": ["dimensions", "use_activation"],
"token_classify": ["use_activation"], "token_classify": ["use_activation"],
} }
...@@ -162,8 +162,8 @@ class PoolingParams( ...@@ -162,8 +162,8 @@ class PoolingParams(
def _set_default_parameters(self, model_config: Optional["ModelConfig"]): def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
if self.task in ["embed", "token_embed"]: if self.task in ["embed", "token_embed"]:
if self.normalize is None: if self.use_activation is None:
self.normalize = True self.use_activation = True
if self.dimensions is not None and model_config is not None: if self.dimensions is not None and model_config is not None:
if not model_config.is_matryoshka: if not model_config.is_matryoshka:
...@@ -213,7 +213,6 @@ class PoolingParams( ...@@ -213,7 +213,6 @@ class PoolingParams(
return ( return (
f"PoolingParams(" f"PoolingParams("
f"task={self.task}, " f"task={self.task}, "
f"normalize={self.normalize}, "
f"dimensions={self.dimensions}, " f"dimensions={self.dimensions}, "
f"use_activation={self.use_activation}, " f"use_activation={self.use_activation}, "
f"step_tag_id={self.step_tag_id}, " f"step_tag_id={self.step_tag_id}, "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment