Unverified commit 9609327f, authored by Nan Qin and committed by GitHub
Browse files

[Core] [Bugfix]: tensor parallel with prompt embeds (#18171)


Signed-off-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Andrew Sansom <andrew@protopia.ai>
parent f07a673e
...@@ -8,12 +8,13 @@ import weakref ...@@ -8,12 +8,13 @@ import weakref
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
import torch
from vllm import LLM from vllm import LLM, envs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1 from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
from ..conftest import VllmRunner from ..conftest import HfRunner, VllmRunner
from ..models.utils import check_outputs_equal from ..models.utils import check_outputs_equal
from ..utils import multi_gpu_test from ..utils import multi_gpu_test
...@@ -43,11 +44,26 @@ def test_vllm_gc_ed(): ...@@ -43,11 +44,26 @@ def test_vllm_gc_ed():
assert weak_llm() is None assert weak_llm() is None
def _fix_prompt_embed_outputs(
vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner,
example_prompts: list[str]) -> list[tuple[list[int], str]]:
fixed_vllm_outputs = []
for vllm_output, hf_input, prompt in zip(
vllm_outputs, hf_model.get_inputs(example_prompts),
example_prompts):
hf_input_ids = hf_input["input_ids"].tolist()[0]
fixed_vllm_outputs.append(
(hf_input_ids + vllm_output[0][len(hf_input_ids):],
prompt + vllm_output[1]))
return fixed_vllm_outputs
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN"]) @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False]) @pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models( def test_models(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
hf_runner, hf_runner,
...@@ -56,8 +72,13 @@ def test_models( ...@@ -56,8 +72,13 @@ def test_models(
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
enforce_eager: bool, enforce_eager: bool,
enable_prompt_embeds: bool,
) -> None: ) -> None:
if enable_prompt_embeds and envs.is_set(
"VLLM_USE_V1") and envs.VLLM_USE_V1:
pytest.skip("enable_prompt_embeds is not supported in v1.")
if backend == "FLASHINFER" and current_platform.is_rocm(): if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.") pytest.skip("Flashinfer does not support ROCm/HIP.")
...@@ -78,14 +99,25 @@ def test_models( ...@@ -78,14 +99,25 @@ def test_models(
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
if enable_prompt_embeds:
with torch.no_grad():
prompt_embeds = hf_model.get_prompt_embeddings(
example_prompts)
with VllmRunner(model, with VllmRunner(model,
max_model_len=8192, max_model_len=8192,
dtype=dtype, dtype=dtype,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7) as vllm_model: gpu_memory_utilization=0.7) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, if enable_prompt_embeds:
max_tokens) vllm_outputs = vllm_model.generate_greedy(
prompt_embeds, max_tokens)
vllm_outputs = _fix_prompt_embed_outputs(
vllm_outputs, hf_model, example_prompts)
else:
vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)
check_outputs_equal( check_outputs_equal(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
...@@ -108,6 +140,7 @@ def test_models( ...@@ -108,6 +140,7 @@ def test_models(
("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"), ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"),
("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"), ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
]) ])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models_distributed( def test_models_distributed(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
hf_runner, hf_runner,
...@@ -117,14 +150,22 @@ def test_models_distributed( ...@@ -117,14 +150,22 @@ def test_models_distributed(
distributed_executor_backend: str, distributed_executor_backend: str,
attention_backend: str, attention_backend: str,
test_suite: str, test_suite: str,
enable_prompt_embeds: bool,
) -> None: ) -> None:
if enable_prompt_embeds and envs.is_set(
"VLLM_USE_V1") and envs.VLLM_USE_V1:
pytest.skip("enable_prompt_embeds is not supported in v1.")
if test_suite != TARGET_TEST_SUITE: if test_suite != TARGET_TEST_SUITE:
pytest.skip(f"Skip test for {test_suite}") pytest.skip(f"Skip test for {test_suite}")
with monkeypatch.context() as monkeypatch_context: with monkeypatch.context() as monkeypatch_context:
if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa
# test Ray Compiled Graph if enable_prompt_embeds:
pytest.skip(
"enable_prompt_embeds does not work with ray compiled dag."
)
monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1") monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1") monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
...@@ -147,12 +188,26 @@ def test_models_distributed( ...@@ -147,12 +188,26 @@ def test_models_distributed(
dtype=dtype, dtype=dtype,
tensor_parallel_size=2, tensor_parallel_size=2,
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, if enable_prompt_embeds:
max_tokens)
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) with torch.no_grad():
prompt_embeds = hf_model.get_prompt_embeddings(
example_prompts)
vllm_outputs = vllm_model.generate_greedy(
prompt_embeds, max_tokens)
vllm_outputs = _fix_prompt_embed_outputs(
vllm_outputs, hf_model, example_prompts)
hf_outputs = hf_model.generate_greedy(
example_prompts, max_tokens)
else:
vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(
example_prompts, max_tokens)
check_outputs_equal( check_outputs_equal(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
......
...@@ -430,6 +430,15 @@ class HfRunner: ...@@ -430,6 +430,15 @@ class HfRunner:
return all_inputs return all_inputs
def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
    """Return one input-embedding tensor per prompt.

    Each prompt is turned into model inputs via ``self.get_inputs``, moved
    with ``self.wrap_device``, and its ``"input_ids"`` are passed through
    the layer returned by ``self.model.get_input_embeddings()``.

    Args:
        prompts: Prompt strings to embed.

    Returns:
        A list with one tensor per prompt, in the same order as
        ``prompts``.
    """
    per_prompt_embeds = []
    for encoded in self.get_inputs(prompts):
        token_ids = self.wrap_device(encoded)["input_ids"]
        embed_layer = self.model.get_input_embeddings()
        # squeeze(0) drops the leading dim when it has size 1
        # (assumes input_ids is (1, seq_len) -- TODO confirm).
        per_prompt_embeds.append(embed_layer(token_ids).squeeze(0))
    return per_prompt_embeds
def classify(self, prompts: list[str]) -> list[str]: def classify(self, prompts: list[str]) -> list[str]:
# output is final logits # output is final logits
all_inputs = self.get_inputs(prompts) all_inputs = self.get_inputs(prompts)
......
...@@ -1123,7 +1123,7 @@ class SequenceOutput( ...@@ -1123,7 +1123,7 @@ class SequenceOutput(
self.output_embed.shape if self.output_embed is not None else None self.output_embed.shape if self.output_embed is not None else None
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}, " f"output_token={self.output_token}, "
f"output_embed.shape={output_embed_shape}" f"output_embed.shape={output_embed_shape}, "
f"logprobs={self.logprobs})") f"logprobs={self.logprobs})")
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
......
...@@ -23,7 +23,7 @@ from vllm.attention.backends.abstract import AttentionState ...@@ -23,7 +23,7 @@ from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import CompilationLevel, VllmConfig from vllm.config import CompilationLevel, VllmConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.distributed.kv_transfer import get_kv_transfer_group from vllm.distributed.kv_transfer import get_kv_transfer_group
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
graph_capture) graph_capture)
...@@ -872,7 +872,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -872,7 +872,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
""" """
# Combine and flatten intermediate data. # Combine and flatten intermediate data.
input_tokens = list[int]() input_tokens = list[int]()
inputs_embeds_lst = list[torch.Tensor]() inputs_embeds_list = list[torch.Tensor]()
token_types = list[int]() token_types = list[int]()
for inter_data in self.inter_data_list: for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens: for cur_input_tokens in inter_data.input_tokens:
...@@ -880,15 +880,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -880,15 +880,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for cur_token_types in inter_data.token_types: for cur_token_types in inter_data.token_types:
token_types.extend(cur_token_types) token_types.extend(cur_token_types)
if inter_data.inputs_embeds is not None: if inter_data.inputs_embeds is not None:
inputs_embeds_lst.append( inputs_embeds_list.append(
inter_data.inputs_embeds.to( inter_data.inputs_embeds.to(
dtype=self.runner.model_config.dtype, dtype=self.runner.model_config.dtype,
device=self.runner.device)) device=self.runner.device))
inputs_embeds: Optional[torch.Tensor] inputs_embeds: Optional[torch.Tensor]
if len(inputs_embeds_lst) == 0: if len(inputs_embeds_list) == 0:
inputs_embeds = None inputs_embeds = None
else: else:
inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to( inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
dtype=self.runner.model_config.dtype, dtype=self.runner.model_config.dtype,
device=self.runner.device) device=self.runner.device)
assert len(inputs_embeds) == len(input_tokens) assert len(inputs_embeds) == len(input_tokens)
...@@ -1893,15 +1893,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1893,15 +1893,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
logits = self.model.compute_logits(hidden_or_intermediate_states, logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata) model_input.sampling_metadata)
if not self.is_driver_worker: if self.is_driver_worker:
return []
if model_input.async_callback is not None: if model_input.async_callback is not None:
model_input.async_callback() model_input.async_callback()
# Sample the next token. # Sample the next token.
assert isinstance(self.sampler, Sampler) assert isinstance(self.sampler, Sampler)
orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
if model_input.inputs_embeds is not None: if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = True self.sampler.include_gpu_probs_tensor = True
...@@ -1919,24 +1917,36 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1919,24 +1917,36 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if intermediate_tensors is not None: if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get( orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item() "model_forward_time", torch.tensor(0.0)).item()
# If there are multiple workers, we are still tracking the latency # If there are multiple workers, we are still tracking the
# from the start time of the driver worker to the end time of the # latency from the start time of the driver worker to the end
# driver worker. The model forward time will then end up covering # time of the driver worker. The model forward time will then
# the communication time as well. # end up covering the communication time as well.
output.model_forward_time = (orig_model_forward_time + output.model_forward_time = (orig_model_forward_time +
model_forward_time) model_forward_time)
if model_input.inputs_embeds is not None: if model_input.inputs_embeds is not None:
if self.is_driver_worker:
sampled = broadcast_tensor_dict(
{"token_ids": output.sampled_token_ids})
else:
sampled = broadcast_tensor_dict()
if sampled["token_ids"] is not None:
sampled_token_embeds = self.model.get_input_embeddings(
sampled["token_ids"].squeeze(1))
if self.is_driver_worker:
self.sampler.include_gpu_probs_tensor = \ self.sampler.include_gpu_probs_tensor = \
orig_include_gpu_probs_tensor orig_include_gpu_probs
if output.sampled_token_ids is not None:
output.sampled_token_embeds = self.model.get_input_embeddings( output.sampled_token_embeds = sampled_token_embeds
output.sampled_token_ids.squeeze(1))
for token_embed, sequence_group_output in zip( for token_embed, sequence_group_output in zip(
output.sampled_token_embeds, output.outputs): output.sampled_token_embeds, output.outputs):
assert len(sequence_group_output.samples) == 1 assert len(sequence_group_output.samples) == 1
sequence_group_output.samples[0].output_embed = token_embed sequence_group_output.samples[
0].output_embed = token_embed
if not self.is_driver_worker:
return []
if self.return_hidden_states: if self.return_hidden_states:
# we only need to pass hidden states of most recent token # we only need to pass hidden states of most recent token
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment