Unverified Commit ddcec289 authored by Andrew Sansom's avatar Andrew Sansom Committed by GitHub
Browse files

Fix implementation divergence for BLOOM models between vLLM and HuggingFace...


Fix implementation divergence for BLOOM models between vLLM and HuggingFace when using prompt embeds (#24686)
Signed-off-by: default avatarAndrew Sansom <andrew@protopia.ai>
parent e090b7b4
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Optional from typing import Optional
import pytest import pytest
...@@ -99,9 +98,10 @@ AITER_MODEL_LIST = [ ...@@ -99,9 +98,10 @@ AITER_MODEL_LIST = [
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
@pytest.mark.parametrize("use_prompt_embeds", [True, False])
def test_models(hf_runner, vllm_runner, example_prompts, model: str, def test_models(hf_runner, vllm_runner, example_prompts, model: str,
max_tokens: int, num_logprobs: int, use_rocm_aiter: bool, max_tokens: int, num_logprobs: int, use_rocm_aiter: bool,
monkeypatch) -> None: use_prompt_embeds: bool, monkeypatch) -> None:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip") model_info.check_available_online(on_fail="skip")
...@@ -119,8 +119,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, ...@@ -119,8 +119,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
# in parts of the operators # in parts of the operators
pytest.skip(f"Skipping '{model}' model test with AITER kernel.") pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
use_prompt_embeds = os.getenv("VLLM_USE_V1") == "0"
with hf_runner(model) as hf_model: with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit( hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
......
...@@ -257,7 +257,7 @@ class BloomModel(nn.Module): ...@@ -257,7 +257,7 @@ class BloomModel(nn.Module):
config.hidden_size)) config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.word_embeddings_layernorm(self.word_embeddings(input_ids)) return self.word_embeddings(input_ids)
def forward( def forward(
self, self,
...@@ -271,6 +271,7 @@ class BloomModel(nn.Module): ...@@ -271,6 +271,7 @@ class BloomModel(nn.Module):
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states)
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment