Unverified Commit c9479b29 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix] Fix the failing gte embedding test (#18720)


Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
parent 6f290940
...@@ -311,6 +311,7 @@ class HfRunner: ...@@ -311,6 +311,7 @@ class HfRunner:
dtype: str = "auto", dtype: str = "auto",
*, *,
model_kwargs: Optional[dict[str, Any]] = None, model_kwargs: Optional[dict[str, Any]] = None,
trust_remote_code: bool = True,
is_sentence_transformer: bool = False, is_sentence_transformer: bool = False,
is_cross_encoder: bool = False, is_cross_encoder: bool = False,
skip_tokenizer_init: bool = False, skip_tokenizer_init: bool = False,
...@@ -320,7 +321,7 @@ class HfRunner: ...@@ -320,7 +321,7 @@ class HfRunner:
self.config = AutoConfig.from_pretrained( self.config = AutoConfig.from_pretrained(
model_name, model_name,
trust_remote_code=True, trust_remote_code=trust_remote_code,
) )
self.device = self.get_default_device() self.device = self.get_default_device()
self.dtype = torch_dtype = _get_and_verify_dtype(self.config, dtype) self.dtype = torch_dtype = _get_and_verify_dtype(self.config, dtype)
...@@ -336,7 +337,7 @@ class HfRunner: ...@@ -336,7 +337,7 @@ class HfRunner:
model_name, model_name,
device=self.device, device=self.device,
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
trust_remote_code=True, trust_remote_code=trust_remote_code,
) )
elif is_cross_encoder: elif is_cross_encoder:
# Lazy init required for AMD CI # Lazy init required for AMD CI
...@@ -346,12 +347,12 @@ class HfRunner: ...@@ -346,12 +347,12 @@ class HfRunner:
model_name, model_name,
device=self.device, device=self.device,
automodel_args=model_kwargs, automodel_args=model_kwargs,
trust_remote_code=True, trust_remote_code=trust_remote_code,
) )
else: else:
model = auto_cls.from_pretrained( model = auto_cls.from_pretrained(
model_name, model_name,
trust_remote_code=True, trust_remote_code=trust_remote_code,
**model_kwargs, **model_kwargs,
) )
...@@ -372,7 +373,7 @@ class HfRunner: ...@@ -372,7 +373,7 @@ class HfRunner:
self.tokenizer = AutoTokenizer.from_pretrained( self.tokenizer = AutoTokenizer.from_pretrained(
model_name, model_name,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
trust_remote_code=True, trust_remote_code=trust_remote_code,
) )
# don't put this import at the top level # don't put this import at the top level
...@@ -381,7 +382,7 @@ class HfRunner: ...@@ -381,7 +382,7 @@ class HfRunner:
self.processor = AutoProcessor.from_pretrained( self.processor = AutoProcessor.from_pretrained(
model_name, model_name,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
trust_remote_code=True, trust_remote_code=trust_remote_code,
) )
if skip_tokenizer_init: if skip_tokenizer_init:
self.tokenizer = self.processor.tokenizer self.tokenizer = self.processor.tokenizer
......
...@@ -10,18 +10,22 @@ from ...utils import check_embeddings_close ...@@ -10,18 +10,22 @@ from ...utils import check_embeddings_close
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", "model",
[ [
# [Encoder-only] # Be careful of the order of models, decoder-only models should be
pytest.param("BAAI/bge-base-en-v1.5", # placed before encoder-only models, otherwise `Qwen2.5-0.5B-Instruct`
marks=[pytest.mark.core_model, pytest.mark.cpu_model]), # case won't pass because gte-Qwen2-1.5B-instruct will cache custom
pytest.param("sentence-transformers/all-MiniLM-L12-v2"), # model code with bidirectional attention.
pytest.param("intfloat/multilingual-e5-small"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
# [Decoder-only] # [Decoder-only]
pytest.param("BAAI/bge-multilingual-gemma2", pytest.param("BAAI/bge-multilingual-gemma2",
marks=[pytest.mark.core_model]), marks=[pytest.mark.core_model]),
pytest.param("intfloat/e5-mistral-7b-instruct", pytest.param("intfloat/e5-mistral-7b-instruct",
marks=[pytest.mark.core_model, pytest.mark.cpu_model]), marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"), pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
# [Encoder-only]
pytest.param("BAAI/bge-base-en-v1.5",
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
pytest.param("intfloat/multilingual-e5-small"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
# [Cross-Encoder] # [Cross-Encoder]
pytest.param("sentence-transformers/stsb-roberta-base-v2"), pytest.param("sentence-transformers/stsb-roberta-base-v2"),
], ],
...@@ -44,7 +48,7 @@ def test_models( ...@@ -44,7 +48,7 @@ def test_models(
vllm_extra_kwargs = {} vllm_extra_kwargs = {}
if model == "ssmits/Qwen2-7B-Instruct-embed-base": if model == "ssmits/Qwen2-7B-Instruct-embed-base":
vllm_extra_kwargs["override_pooler_config"] = \ vllm_extra_kwargs["override_pooler_config"] = \
PoolerConfig(pooling_type="MEAN") PoolerConfig(pooling_type="MEAN", normalize=False)
# The example_prompts has ending "\n", for example: # The example_prompts has ending "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n" # "Write a short story about a robot that dreams for the first time.\n"
......
...@@ -45,6 +45,7 @@ MODELS = [ ...@@ -45,6 +45,7 @@ MODELS = [
########### Qwen2ForCausalLM ########### Qwen2ForCausalLM
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
architecture="Qwen2ForCausalLM", architecture="Qwen2ForCausalLM",
dtype="float32",
enable_test=True), enable_test=True),
########## ModernBertModel ########## ModernBertModel
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base", EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
......
...@@ -314,6 +314,7 @@ def check_embeddings_close( ...@@ -314,6 +314,7 @@ def check_embeddings_close(
dim=0) dim=0)
fail_msg = (f"Test{prompt_idx}:" fail_msg = (f"Test{prompt_idx}:"
f"\nCosine similarity: \t{sim:.4f}"
f"\n{name_0}:\t{embeddings_0[:16]!r}" f"\n{name_0}:\t{embeddings_0[:16]!r}"
f"\n{name_1}:\t{embeddings_1[:16]!r}") f"\n{name_1}:\t{embeddings_1[:16]!r}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment