Commit 9609327f (unverified), authored by Nan Qin, committed by GitHub
Browse files

[Core] [Bugfix]: tensor parallel with prompt embeds (#18171)


Signed-off-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Andrew Sansom <andrew@protopia.ai>
parent f07a673e
...@@ -8,12 +8,13 @@ import weakref ...@@ -8,12 +8,13 @@ import weakref
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
import torch
from vllm import LLM from vllm import LLM, envs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1 from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
from ..conftest import VllmRunner from ..conftest import HfRunner, VllmRunner
from ..models.utils import check_outputs_equal from ..models.utils import check_outputs_equal
from ..utils import multi_gpu_test from ..utils import multi_gpu_test
...@@ -43,11 +44,26 @@ def test_vllm_gc_ed(): ...@@ -43,11 +44,26 @@ def test_vllm_gc_ed():
assert weak_llm() is None assert weak_llm() is None
def _fix_prompt_embed_outputs(
vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner,
example_prompts: list[str]) -> list[tuple[list[int], str]]:
fixed_vllm_outputs = []
for vllm_output, hf_input, prompt in zip(
vllm_outputs, hf_model.get_inputs(example_prompts),
example_prompts):
hf_input_ids = hf_input["input_ids"].tolist()[0]
fixed_vllm_outputs.append(
(hf_input_ids + vllm_output[0][len(hf_input_ids):],
prompt + vllm_output[1]))
return fixed_vllm_outputs
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN"]) @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False]) @pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models( def test_models(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
hf_runner, hf_runner,
...@@ -56,8 +72,13 @@ def test_models( ...@@ -56,8 +72,13 @@ def test_models(
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
enforce_eager: bool, enforce_eager: bool,
enable_prompt_embeds: bool,
) -> None: ) -> None:
if enable_prompt_embeds and envs.is_set(
"VLLM_USE_V1") and envs.VLLM_USE_V1:
pytest.skip("enable_prompt_embeds is not supported in v1.")
if backend == "FLASHINFER" and current_platform.is_rocm(): if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.") pytest.skip("Flashinfer does not support ROCm/HIP.")
...@@ -78,14 +99,25 @@ def test_models( ...@@ -78,14 +99,25 @@ def test_models(
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
if enable_prompt_embeds:
with torch.no_grad():
prompt_embeds = hf_model.get_prompt_embeddings(
example_prompts)
with VllmRunner(model, with VllmRunner(model,
max_model_len=8192, max_model_len=8192,
dtype=dtype, dtype=dtype,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7) as vllm_model: gpu_memory_utilization=0.7) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, if enable_prompt_embeds:
max_tokens) vllm_outputs = vllm_model.generate_greedy(
prompt_embeds, max_tokens)
vllm_outputs = _fix_prompt_embed_outputs(
vllm_outputs, hf_model, example_prompts)
else:
vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)
check_outputs_equal( check_outputs_equal(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
...@@ -108,6 +140,7 @@ def test_models( ...@@ -108,6 +140,7 @@ def test_models(
("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"), ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"),
("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"), ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
]) ])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models_distributed( def test_models_distributed(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
hf_runner, hf_runner,
...@@ -117,14 +150,22 @@ def test_models_distributed( ...@@ -117,14 +150,22 @@ def test_models_distributed(
distributed_executor_backend: str, distributed_executor_backend: str,
attention_backend: str, attention_backend: str,
test_suite: str, test_suite: str,
enable_prompt_embeds: bool,
) -> None: ) -> None:
if enable_prompt_embeds and envs.is_set(
"VLLM_USE_V1") and envs.VLLM_USE_V1:
pytest.skip("enable_prompt_embeds is not supported in v1.")
if test_suite != TARGET_TEST_SUITE: if test_suite != TARGET_TEST_SUITE:
pytest.skip(f"Skip test for {test_suite}") pytest.skip(f"Skip test for {test_suite}")
with monkeypatch.context() as monkeypatch_context: with monkeypatch.context() as monkeypatch_context:
if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa
# test Ray Compiled Graph if enable_prompt_embeds:
pytest.skip(
"enable_prompt_embeds does not work with ray compiled dag."
)
monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1") monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1") monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
...@@ -147,12 +188,26 @@ def test_models_distributed( ...@@ -147,12 +188,26 @@ def test_models_distributed(
dtype=dtype, dtype=dtype,
tensor_parallel_size=2, tensor_parallel_size=2,
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, if enable_prompt_embeds:
max_tokens) with hf_runner(model, dtype=dtype) as hf_model:
with torch.no_grad():
with hf_runner(model, dtype=dtype) as hf_model: prompt_embeds = hf_model.get_prompt_embeddings(
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) example_prompts)
vllm_outputs = vllm_model.generate_greedy(
prompt_embeds, max_tokens)
vllm_outputs = _fix_prompt_embed_outputs(
vllm_outputs, hf_model, example_prompts)
hf_outputs = hf_model.generate_greedy(
example_prompts, max_tokens)
else:
vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(
example_prompts, max_tokens)
check_outputs_equal( check_outputs_equal(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
......
...@@ -430,6 +430,15 @@ class HfRunner: ...@@ -430,6 +430,15 @@ class HfRunner:
return all_inputs return all_inputs
def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
    """Return the model's input embeddings for each prompt.

    Each entry of ``self.get_inputs(prompts)`` is passed through
    ``self.wrap_device``; its ``"input_ids"`` are then fed to the module
    returned by ``self.model.get_input_embeddings()`` and dimension 0 is
    squeezed from the result.

    Args:
        prompts: Prompt strings to embed.

    Returns:
        One embedding tensor per prompt, in prompt order.
    """

    def _embed_one(model_inputs):
        token_ids = self.wrap_device(model_inputs)["input_ids"]
        embed_layer = self.model.get_input_embeddings()
        # squeeze(0) only drops the leading dimension when it has size 1.
        # assumes input_ids is (1, seq_len) -- TODO confirm
        return embed_layer(token_ids).squeeze(0)

    return [_embed_one(model_inputs)
            for model_inputs in self.get_inputs(prompts)]
def classify(self, prompts: list[str]) -> list[str]: def classify(self, prompts: list[str]) -> list[str]:
# output is final logits # output is final logits
all_inputs = self.get_inputs(prompts) all_inputs = self.get_inputs(prompts)
......
...@@ -112,12 +112,12 @@ class RequestMetrics: ...@@ -112,12 +112,12 @@ class RequestMetrics:
will include model forward, block/sync across will include model forward, block/sync across
workers, cpu-gpu sync time and sampling time. workers, cpu-gpu sync time and sampling time.
spec_token_acceptance_counts: number of accepted speculative tokens at spec_token_acceptance_counts: number of accepted speculative tokens at
each position; the first token is from each position; the first token is from
the target model and is always accepted; the target model and is always accepted;
e.g., when it's [10, 8, 4, 2] for a req, e.g., when it's [10, 8, 4, 2] for a req,
it means there were 10 forward passes in it means there were 10 forward passes in
total, and there were 8, 4, 2 accepted total, and there were 8, 4, 2 accepted
tokens at 1st, 2nd, 3rd speculation step. tokens at 1st, 2nd, 3rd speculation step.
""" """
arrival_time: float arrival_time: float
last_token_time: float last_token_time: float
...@@ -714,9 +714,9 @@ class SequenceGroup: ...@@ -714,9 +714,9 @@ class SequenceGroup:
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request. prompt_adapter_request: Prompt Adapter request.
priority: User-defined priority of the request. priority: User-defined priority of the request.
draft_size: The number of speculative tokens plus one from the target draft_size: The number of speculative tokens plus one from the target
model; equal to max number of tokens a step can generate model; equal to max number of tokens a step can generate
for single-draft speculative decoding but larger than for single-draft speculative decoding but larger than
that for multi-draft SD (currently not supported). that for multi-draft SD (currently not supported).
""" """
...@@ -1123,7 +1123,7 @@ class SequenceOutput( ...@@ -1123,7 +1123,7 @@ class SequenceOutput(
self.output_embed.shape if self.output_embed is not None else None self.output_embed.shape if self.output_embed is not None else None
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}, " f"output_token={self.output_token}, "
f"output_embed.shape={output_embed_shape}" f"output_embed.shape={output_embed_shape}, "
f"logprobs={self.logprobs})") f"logprobs={self.logprobs})")
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
......
...@@ -23,7 +23,7 @@ from vllm.attention.backends.abstract import AttentionState ...@@ -23,7 +23,7 @@ from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import CompilationLevel, VllmConfig from vllm.config import CompilationLevel, VllmConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.distributed.kv_transfer import get_kv_transfer_group from vllm.distributed.kv_transfer import get_kv_transfer_group
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
graph_capture) graph_capture)
...@@ -872,7 +872,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -872,7 +872,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
""" """
# Combine and flatten intermediate data. # Combine and flatten intermediate data.
input_tokens = list[int]() input_tokens = list[int]()
inputs_embeds_lst = list[torch.Tensor]() inputs_embeds_list = list[torch.Tensor]()
token_types = list[int]() token_types = list[int]()
for inter_data in self.inter_data_list: for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens: for cur_input_tokens in inter_data.input_tokens:
...@@ -880,15 +880,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -880,15 +880,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for cur_token_types in inter_data.token_types: for cur_token_types in inter_data.token_types:
token_types.extend(cur_token_types) token_types.extend(cur_token_types)
if inter_data.inputs_embeds is not None: if inter_data.inputs_embeds is not None:
inputs_embeds_lst.append( inputs_embeds_list.append(
inter_data.inputs_embeds.to( inter_data.inputs_embeds.to(
dtype=self.runner.model_config.dtype, dtype=self.runner.model_config.dtype,
device=self.runner.device)) device=self.runner.device))
inputs_embeds: Optional[torch.Tensor] inputs_embeds: Optional[torch.Tensor]
if len(inputs_embeds_lst) == 0: if len(inputs_embeds_list) == 0:
inputs_embeds = None inputs_embeds = None
else: else:
inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to( inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
dtype=self.runner.model_config.dtype, dtype=self.runner.model_config.dtype,
device=self.runner.device) device=self.runner.device)
assert len(inputs_embeds) == len(input_tokens) assert len(inputs_embeds) == len(input_tokens)
...@@ -1893,50 +1893,60 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1893,50 +1893,60 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
logits = self.model.compute_logits(hidden_or_intermediate_states, logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata) model_input.sampling_metadata)
if not self.is_driver_worker: if self.is_driver_worker:
return [] if model_input.async_callback is not None:
model_input.async_callback()
if model_input.async_callback is not None: # Sample the next token.
model_input.async_callback() assert isinstance(self.sampler, Sampler)
orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = True
# Sample the next token. output: SamplerOutput = self.sampler(
assert isinstance(self.sampler, Sampler) logits=logits,
orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor sampling_metadata=model_input.sampling_metadata,
if model_input.inputs_embeds is not None: )
self.sampler.include_gpu_probs_tensor = True if (self.observability_config is not None
and self.observability_config.collect_model_forward_time
output: SamplerOutput = self.sampler( and output is not None):
logits=logits, model_forward_end.synchronize()
sampling_metadata=model_input.sampling_metadata, model_forward_time = model_forward_start.elapsed_time(
) model_forward_end)
if (self.observability_config is not None orig_model_forward_time = 0.0
and self.observability_config.collect_model_forward_time if intermediate_tensors is not None:
and output is not None): orig_model_forward_time = intermediate_tensors.tensors.get(
model_forward_end.synchronize() "model_forward_time", torch.tensor(0.0)).item()
model_forward_time = model_forward_start.elapsed_time( # If there are multiple workers, we are still tracking the
model_forward_end) # latency from the start time of the driver worker to the end
orig_model_forward_time = 0.0 # time of the driver worker. The model forward time will then
if intermediate_tensors is not None: # end up covering the communication time as well.
orig_model_forward_time = intermediate_tensors.tensors.get( output.model_forward_time = (orig_model_forward_time +
"model_forward_time", torch.tensor(0.0)).item() model_forward_time)
# If there are multiple workers, we are still tracking the latency
# from the start time of the driver worker to the end time of the
# driver worker. The model forward time will then end up covering
# the communication time as well.
output.model_forward_time = (orig_model_forward_time +
model_forward_time)
if model_input.inputs_embeds is not None: if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = \ if self.is_driver_worker:
orig_include_gpu_probs_tensor sampled = broadcast_tensor_dict(
if output.sampled_token_ids is not None: {"token_ids": output.sampled_token_ids})
output.sampled_token_embeds = self.model.get_input_embeddings( else:
output.sampled_token_ids.squeeze(1)) sampled = broadcast_tensor_dict()
if sampled["token_ids"] is not None:
for token_embed, sequence_group_output in zip( sampled_token_embeds = self.model.get_input_embeddings(
output.sampled_token_embeds, output.outputs): sampled["token_ids"].squeeze(1))
assert len(sequence_group_output.samples) == 1 if self.is_driver_worker:
sequence_group_output.samples[0].output_embed = token_embed self.sampler.include_gpu_probs_tensor = \
orig_include_gpu_probs
output.sampled_token_embeds = sampled_token_embeds
for token_embed, sequence_group_output in zip(
output.sampled_token_embeds, output.outputs):
assert len(sequence_group_output.samples) == 1
sequence_group_output.samples[
0].output_embed = token_embed
if not self.is_driver_worker:
return []
if self.return_hidden_states: if self.return_hidden_states:
# we only need to pass hidden states of most recent token # we only need to pass hidden states of most recent token
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment