Unverified commit 8f5b9910, authored by nathan, committed by GitHub
Browse files

Add support for Qwen3-seq-cls (#9357)

parent ef3004d9
@@ -42,7 +42,13 @@ class Qwen3ForSequenceClassification(nn.Module):
         # Use normalize=True for qwen3 embedding based on official implementation
         # Reference: https://github.com/QwenLM/Qwen3-Embedding/blob/main/examples/qwen3_embedding_transformers.py#L55
         # Official code: output = F.normalize(output, p=2, dim=1)
-        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        normalize = True
+        # We don't want to normalize the embedding if we have a classification head
+        if config.id2label is not None or config.label2id is not None:
+            normalize = False
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=normalize)
         self.eos_token_id = config.eos_token_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment