Unverified commit 97286a20, authored by zhrrr, committed by GitHub
Browse files

[Model Runner V2] support dp & ep for spec decoding (#35294)


Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Co-authored-by: Giancarlo Delfin <gdelfin@inferact.ai>
parent 12b38c0f
...@@ -57,10 +57,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables ...@@ -57,10 +57,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.dp_utils import ( from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
get_cudagraph_and_dp_padding,
make_num_tokens_across_dp,
)
from vllm.v1.worker.gpu.input_batch import ( from vllm.v1.worker.gpu.input_batch import (
InputBatch, InputBatch,
InputBuffers, InputBuffers,
...@@ -265,7 +262,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -265,7 +262,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
prepare_communication_buffer_for_model(self.model) prepare_communication_buffer_for_model(self.model)
if self.speculator is not None: if self.speculator is not None:
prepare_communication_buffer_for_model(self.speculator) prepare_communication_buffer_for_model(self.speculator.model)
# Initialize the components that require the model. # Initialize the components that require the model.
self.model_state = init_model_state( self.model_state = init_model_state(
...@@ -382,8 +379,41 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -382,8 +379,41 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return None, None return None, None
assert self.execute_model_state is not None assert self.execute_model_state is not None
input_batch, _, _, _, hidden_states, _, _ = self.execute_model_state (
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
) = self.execute_model_state
self.execute_model_state = None self.execute_model_state = None
# dummy run the eagle speculator's propose to ensure DP/EP sync.
if self.speculator is not None:
self.speculator.propose(
input_batch=input_batch,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings_by_layer,
last_hidden_states=hidden_states,
aux_hidden_states=aux_hidden_states,
num_sampled=torch.ones(
input_batch.num_reqs, dtype=torch.int32, device=self.device
),
num_rejected=torch.zeros(
input_batch.num_reqs, dtype=torch.int32, device=self.device
),
last_sampled=self.req_states.last_sampled_tokens,
next_prefill_tokens=self.req_states.next_prefill_tokens,
temperature=self.sampler.sampling_states.temperature.gpu,
seeds=self.sampler.sampling_states.seeds.gpu,
num_tokens_across_dp=num_tokens_across_dp,
dummy_run=True,
skip_attn_for_dummy_run=skip_attn,
)
assert hidden_states is not None # Last PP rank always has hidden_states assert hidden_states is not None # Last PP rank always has hidden_states
sample_hidden_states = hidden_states[input_batch.logits_indices] sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states return hidden_states, sample_hidden_states
...@@ -431,17 +461,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -431,17 +461,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
self._dummy_pooler_run(hidden_states) self._dummy_pooler_run(hidden_states)
if self.speculator is not None:
num_tokens_across_dp = make_num_tokens_across_dp(
self.parallel_config.data_parallel_size, self.max_num_tokens
)
self.speculator.run_model(
self.max_num_tokens,
attn_metadata=None,
slot_mappings=None,
num_tokens_across_dp=num_tokens_across_dp,
)
torch.cuda.synchronize() torch.cuda.synchronize()
del hidden_states, sample_hidden_states del hidden_states, sample_hidden_states
gc.collect() gc.collect()
...@@ -979,6 +998,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -979,6 +998,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states, hidden_states,
aux_hidden_states, aux_hidden_states,
kv_connector_output, kv_connector_output,
num_tokens_across_dp,
) )
if not self.is_last_pp_rank: if not self.is_last_pp_rank:
...@@ -1005,6 +1025,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1005,6 +1025,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states, hidden_states,
aux_hidden_states, aux_hidden_states,
kv_connector_output, kv_connector_output,
num_tokens_across_dp,
) = self.execute_model_state ) = self.execute_model_state
self.execute_model_state = None self.execute_model_state = None
...@@ -1078,6 +1099,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1078,6 +1099,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.req_states.next_prefill_tokens, self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu, self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu, self.sampler.sampling_states.seeds.gpu,
num_tokens_across_dp=num_tokens_across_dp,
) )
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens) self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)
......
...@@ -55,6 +55,26 @@ class EagleCudaGraphManager: ...@@ -55,6 +55,26 @@ class EagleCudaGraphManager:
def get_cudagraph_size(self, num_tokens: int) -> int | None: def get_cudagraph_size(self, num_tokens: int) -> int | None:
return self.cudagraph_sizes.get(num_tokens) return self.cudagraph_sizes.get(num_tokens)
def get_cudagraph_runtime_mode(
    self, num_tokens: int
) -> tuple[CUDAGraphMode, int | None]:
    """Pick the CUDA graph mode and padded size to use for `num_tokens`.

    Args:
        num_tokens: Token count to look up via `get_cudagraph_size`.

    Returns:
        A `(mode, size)` pair. `mode` is `CUDAGraphMode.NONE` when no
        CUDA graph size is registered for `num_tokens`, or when the
        configured mode is FULL but no graph for that size is present in
        `self.graphs`; in both cases `size` is `None`. Otherwise the
        configured `self.cudagraph_mode` and the looked-up size are
        returned as-is.
    """
    padded_size = self.get_cudagraph_size(num_tokens)
    if padded_size is None:
        # No CUDA graph size is registered for this token count.
        return CUDAGraphMode.NONE, None

    configured_mode = self.cudagraph_mode
    if configured_mode == CUDAGraphMode.FULL and padded_size not in self.graphs:
        # The full graph has not been captured yet (e.g. a dummy run that
        # happens before capture), so fall back to eager execution.
        return CUDAGraphMode.NONE, None

    return configured_mode, padded_size
def capture_graph( def capture_graph(
self, self,
num_tokens: int, num_tokens: int,
......
...@@ -16,6 +16,7 @@ from vllm.v1.worker.gpu.attn_utils import ( ...@@ -16,6 +16,7 @@ from vllm.v1.worker.gpu.attn_utils import (
build_slot_mappings_by_layer, build_slot_mappings_by_layer,
) )
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
...@@ -48,6 +49,10 @@ class EagleSpeculator: ...@@ -48,6 +49,10 @@ class EagleSpeculator:
self.vocab_size = self.draft_model_config.get_vocab_size() self.vocab_size = self.draft_model_config.get_vocab_size()
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
# DP configuration
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.input_buffers = InputBuffers( self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens, max_num_tokens=self.max_num_tokens,
...@@ -122,8 +127,8 @@ class EagleSpeculator: ...@@ -122,8 +127,8 @@ class EagleSpeculator:
self, self,
num_reqs: int, num_reqs: int,
num_tokens_padded: int, num_tokens_padded: int,
attn_metadata: dict[str, Any], attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor], slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None, num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
) -> None: ) -> None:
...@@ -164,6 +169,7 @@ class EagleSpeculator: ...@@ -164,6 +169,7 @@ class EagleSpeculator:
self.hidden_states, self.hidden_states,
self.max_model_len, self.max_model_len,
) )
if attn_metadata is not None:
self.block_tables.compute_slot_mappings( self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos idx_mapping, query_start_loc, pos
) )
...@@ -203,6 +209,9 @@ class EagleSpeculator: ...@@ -203,6 +209,9 @@ class EagleSpeculator:
temperature: torch.Tensor, temperature: torch.Tensor,
# [max_num_reqs] # [max_num_reqs]
seeds: torch.Tensor, seeds: torch.Tensor,
num_tokens_across_dp: torch.Tensor | None = None,
dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
# number of rejected tokens, we maintain the size of eagle's input_ids and # number of rejected tokens, we maintain the size of eagle's input_ids and
...@@ -236,7 +245,7 @@ class EagleSpeculator: ...@@ -236,7 +245,7 @@ class EagleSpeculator:
num_tokens, num_tokens,
attn_metadata, attn_metadata,
slot_mappings, slot_mappings,
num_tokens_across_dp=None, # FIXME num_tokens_across_dp=num_tokens_across_dp,
) )
sample_hidden_states = last_hidden_states[last_token_indices] sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
...@@ -282,27 +291,42 @@ class EagleSpeculator: ...@@ -282,27 +291,42 @@ class EagleSpeculator:
self.max_model_len, self.max_model_len,
self.max_num_reqs, self.max_num_reqs,
) )
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings( slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos idx_mapping, query_start_loc, pos
) )
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs) cudagraph_mode, cudagraph_size = (
cudagraph_mode = self.cudagraph_manager.cudagraph_mode self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs)
if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL: )
num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = (
get_cudagraph_and_dp_padding(
num_reqs,
cudagraph_size,
cudagraph_mode.value,
self.dp_size,
self.dp_rank,
)
)
cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode)
if cudagraph_mode == CUDAGraphMode.FULL:
# Run full CUDA graph. # Run full CUDA graph.
self.cudagraph_manager.run_fullgraph(cudagraph_size) self.cudagraph_manager.run_fullgraph(num_tokens_padded)
return self.draft_tokens[:num_reqs] return self.draft_tokens[:num_reqs]
# Run eager or piecewise CUDA graph. # Run eager or piecewise CUDA graph.
num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs attn_metadata_updated = None
slot_mappings_updated = None
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc_cpu = torch.arange( query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu" num_reqs + 1, dtype=torch.int32, device="cpu"
) )
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables] block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
# FIXME(woosuk): This is UNSAFE!! # FIXME(woosuk): This is UNSAFE!!
attn_metadata = build_attn_metadata( attn_metadata_updated = build_attn_metadata(
attn_groups=self.attn_groups, attn_groups=self.attn_groups,
num_reqs=num_reqs, num_reqs=num_reqs,
num_tokens=num_reqs, num_tokens=num_reqs,
...@@ -315,15 +339,16 @@ class EagleSpeculator: ...@@ -315,15 +339,16 @@ class EagleSpeculator:
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config, kv_cache_config=self.kv_cache_config,
) )
slot_mappings_by_layer = build_slot_mappings_by_layer( slot_mappings_updated = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config slot_mappings, self.kv_cache_config
) )
self.generate_draft( self.generate_draft(
num_reqs, num_reqs,
num_tokens_padded, num_tokens_padded,
attn_metadata, attn_metadata_updated,
slot_mappings_by_layer, slot_mappings_updated,
num_tokens_across_dp=None, # FIXME num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_mode, cudagraph_runtime_mode=cudagraph_mode,
) )
return self.draft_tokens[:num_reqs] return self.draft_tokens[:num_reqs]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment