Commit f5b0846b (unverified), authored by Harry Mellor, committed by GitHub
Browse files

Fix some Transformers nightly tests (#29802)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 13ea39bc
...@@ -29,7 +29,7 @@ logger = init_logger(__name__) ...@@ -29,7 +29,7 @@ logger = init_logger(__name__)
class JinaVLScorer(nn.Module): class JinaVLScorer(nn.Module):
def __init__(self, model_config: "ModelConfig"): def __init__(self, model_config: "ModelConfig"):
super().__init__() super().__init__()
config = model_config.hf_config config = model_config.hf_config.get_text_config()
head_dtype = model_config.head_dtype head_dtype = model_config.head_dtype
self.dense = ColumnParallelLinear( self.dense = ColumnParallelLinear(
config.hidden_size, config.hidden_size, params_dtype=head_dtype, bias=True config.hidden_size, config.hidden_size, params_dtype=head_dtype, bias=True
......
...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.pooler import ( ...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.pooler import (
PoolingParamsUpdate, PoolingParamsUpdate,
PoolingType, PoolingType,
) )
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -62,19 +62,6 @@ class ModernBertEmbeddings(nn.Module): ...@@ -62,19 +62,6 @@ class ModernBertEmbeddings(nn.Module):
return embeddings return embeddings
class ModernBertRotaryEmbedding(RotaryEmbedding):
def __init__(self, config: ModernBertConfig, head_size: int, dim: int, base: float):
super().__init__(
head_size=head_size,
rotary_dim=dim,
max_position_embeddings=config.max_position_embeddings,
base=base,
is_neox_style=True,
dtype=torch.float16,
)
self.config = config
class ModernBertAttention(nn.Module): class ModernBertAttention(nn.Module):
def __init__(self, config: ModernBertConfig, layer_id: int | None = None): def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
super().__init__() super().__init__()
...@@ -95,6 +82,15 @@ class ModernBertAttention(nn.Module): ...@@ -95,6 +82,15 @@ class ModernBertAttention(nn.Module):
bias=config.attention_bias, bias=config.attention_bias,
) )
if layer_types := getattr(config, "layer_types", None):
# Transformers v5
layer_type = layer_types[layer_id]
rope_parameters = config.rope_parameters[layer_type]
sliding_window: int | None = None
if layer_type == "sliding_attention":
sliding_window = config.local_attention // 2
else:
# Transformers v4
sliding_window = None sliding_window = None
if layer_id % config.global_attn_every_n_layers != 0: if layer_id % config.global_attn_every_n_layers != 0:
sliding_window = config.local_attention // 2 sliding_window = config.local_attention // 2
...@@ -105,9 +101,14 @@ class ModernBertAttention(nn.Module): ...@@ -105,9 +101,14 @@ class ModernBertAttention(nn.Module):
) )
else: else:
rope_theta = config.global_rope_theta rope_theta = config.global_rope_theta
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
self.rotary_emb = ModernBertRotaryEmbedding( self.rotary_emb = get_rope(
config=config, head_size=self.head_dim, dim=self.head_dim, base=rope_theta head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings,
rope_parameters=rope_parameters,
dtype=torch.float16,
) )
self.attn = EncoderOnlyAttention( self.attn = EncoderOnlyAttention(
self.num_heads, self.num_heads,
......
...@@ -503,7 +503,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): ...@@ -503,7 +503,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config.get_text_config()
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment