Commit 5ac1a8e6 authored by Roger Wang's avatar Roger Wang Committed by simon-mo
Browse files

[Bugfix] Fix interface for Olmo2 on V1 (#14976)


Signed-off-by: default avatarRoger Wang <ywang@roblox.com>
parent 37e38061
...@@ -42,7 +42,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -42,7 +42,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -283,17 +283,19 @@ class Olmo2Model(nn.Module): ...@@ -283,17 +283,19 @@ class Olmo2Model(nn.Module):
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
""" """
:param input_ids: A tensor of shape `(batch_size, seq_len)`. :param input_ids: A tensor of shape `(batch_size, seq_len)`.
""" """
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
# Get embeddings of input. # Get embeddings of input.
# shape: (batch_size, seq_len, d_model) # shape: (batch_size, seq_len, d_model)
inputs_embeds = self.embed_tokens(input_ids) else:
hidden_states = self.embed_tokens(input_ids)
# embed positions
hidden_states = inputs_embeds
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
...@@ -337,7 +339,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP): ...@@ -337,7 +339,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -346,11 +348,13 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP): ...@@ -346,11 +348,13 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
) )
return hidden_states return hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment