Unverified Commit 97286a20 authored by zhrrr's avatar zhrrr Committed by GitHub
Browse files

[Model Runner V2] support dp & ep for spec decoding (#35294)


Signed-off-by: default avatarGiancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: default avatarzhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Co-authored-by: default avatarGiancarlo Delfin <gdelfin@inferact.ai>
parent 12b38c0f
...@@ -57,10 +57,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables ...@@ -57,10 +57,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.dp_utils import ( from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
get_cudagraph_and_dp_padding,
make_num_tokens_across_dp,
)
from vllm.v1.worker.gpu.input_batch import ( from vllm.v1.worker.gpu.input_batch import (
InputBatch, InputBatch,
InputBuffers, InputBuffers,
...@@ -265,7 +262,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -265,7 +262,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
prepare_communication_buffer_for_model(self.model) prepare_communication_buffer_for_model(self.model)
if self.speculator is not None: if self.speculator is not None:
prepare_communication_buffer_for_model(self.speculator) prepare_communication_buffer_for_model(self.speculator.model)
# Initialize the components that require the model. # Initialize the components that require the model.
self.model_state = init_model_state( self.model_state = init_model_state(
...@@ -382,8 +379,41 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -382,8 +379,41 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return None, None return None, None
assert self.execute_model_state is not None assert self.execute_model_state is not None
input_batch, _, _, _, hidden_states, _, _ = self.execute_model_state (
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
) = self.execute_model_state
self.execute_model_state = None self.execute_model_state = None
# dummy run the eagle speculator's propose to ensure DP/EP sync.
if self.speculator is not None:
self.speculator.propose(
input_batch=input_batch,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings_by_layer,
last_hidden_states=hidden_states,
aux_hidden_states=aux_hidden_states,
num_sampled=torch.ones(
input_batch.num_reqs, dtype=torch.int32, device=self.device
),
num_rejected=torch.zeros(
input_batch.num_reqs, dtype=torch.int32, device=self.device
),
last_sampled=self.req_states.last_sampled_tokens,
next_prefill_tokens=self.req_states.next_prefill_tokens,
temperature=self.sampler.sampling_states.temperature.gpu,
seeds=self.sampler.sampling_states.seeds.gpu,
num_tokens_across_dp=num_tokens_across_dp,
dummy_run=True,
skip_attn_for_dummy_run=skip_attn,
)
assert hidden_states is not None # Last PP rank always has hidden_states assert hidden_states is not None # Last PP rank always has hidden_states
sample_hidden_states = hidden_states[input_batch.logits_indices] sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states return hidden_states, sample_hidden_states
...@@ -431,17 +461,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -431,17 +461,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
self._dummy_pooler_run(hidden_states) self._dummy_pooler_run(hidden_states)
if self.speculator is not None:
num_tokens_across_dp = make_num_tokens_across_dp(
self.parallel_config.data_parallel_size, self.max_num_tokens
)
self.speculator.run_model(
self.max_num_tokens,
attn_metadata=None,
slot_mappings=None,
num_tokens_across_dp=num_tokens_across_dp,
)
torch.cuda.synchronize() torch.cuda.synchronize()
del hidden_states, sample_hidden_states del hidden_states, sample_hidden_states
gc.collect() gc.collect()
...@@ -979,6 +998,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -979,6 +998,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states, hidden_states,
aux_hidden_states, aux_hidden_states,
kv_connector_output, kv_connector_output,
num_tokens_across_dp,
) )
if not self.is_last_pp_rank: if not self.is_last_pp_rank:
...@@ -1005,6 +1025,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1005,6 +1025,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states, hidden_states,
aux_hidden_states, aux_hidden_states,
kv_connector_output, kv_connector_output,
num_tokens_across_dp,
) = self.execute_model_state ) = self.execute_model_state
self.execute_model_state = None self.execute_model_state = None
...@@ -1078,6 +1099,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1078,6 +1099,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.req_states.next_prefill_tokens, self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu, self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu, self.sampler.sampling_states.seeds.gpu,
num_tokens_across_dp=num_tokens_across_dp,
) )
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens) self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)
......
...@@ -55,6 +55,26 @@ class EagleCudaGraphManager: ...@@ -55,6 +55,26 @@ class EagleCudaGraphManager:
def get_cudagraph_size(self, num_tokens: int) -> int | None: def get_cudagraph_size(self, num_tokens: int) -> int | None:
return self.cudagraph_sizes.get(num_tokens) return self.cudagraph_sizes.get(num_tokens)
def get_cudagraph_runtime_mode(
self, num_tokens: int
) -> tuple[CUDAGraphMode, int | None]:
cudagraph_size = self.get_cudagraph_size(num_tokens)
if cudagraph_size is None:
cudagraph_mode = CUDAGraphMode.NONE
else:
cudagraph_mode = self.cudagraph_mode
if (
cudagraph_mode == CUDAGraphMode.FULL
and cudagraph_size is not None
and cudagraph_size not in self.graphs
):
# If graph wasn't captured yet, fall back to eager.
# This might happen when the dummy run is called before capture.
cudagraph_mode = CUDAGraphMode.NONE
cudagraph_size = None
return cudagraph_mode, cudagraph_size
def capture_graph( def capture_graph(
self, self,
num_tokens: int, num_tokens: int,
......
...@@ -16,6 +16,7 @@ from vllm.v1.worker.gpu.attn_utils import ( ...@@ -16,6 +16,7 @@ from vllm.v1.worker.gpu.attn_utils import (
build_slot_mappings_by_layer, build_slot_mappings_by_layer,
) )
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
...@@ -48,6 +49,10 @@ class EagleSpeculator: ...@@ -48,6 +49,10 @@ class EagleSpeculator:
self.vocab_size = self.draft_model_config.get_vocab_size() self.vocab_size = self.draft_model_config.get_vocab_size()
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
# DP configuration
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.input_buffers = InputBuffers( self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens, max_num_tokens=self.max_num_tokens,
...@@ -122,8 +127,8 @@ class EagleSpeculator: ...@@ -122,8 +127,8 @@ class EagleSpeculator:
self, self,
num_reqs: int, num_reqs: int,
num_tokens_padded: int, num_tokens_padded: int,
attn_metadata: dict[str, Any], attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor], slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None, num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
) -> None: ) -> None:
...@@ -164,9 +169,10 @@ class EagleSpeculator: ...@@ -164,9 +169,10 @@ class EagleSpeculator:
self.hidden_states, self.hidden_states,
self.max_model_len, self.max_model_len,
) )
self.block_tables.compute_slot_mappings( if attn_metadata is not None:
idx_mapping, query_start_loc, pos self.block_tables.compute_slot_mappings(
) idx_mapping, query_start_loc, pos
)
def capture_model(self) -> None: def capture_model(self) -> None:
if self.num_speculative_steps == 1: if self.num_speculative_steps == 1:
...@@ -203,6 +209,9 @@ class EagleSpeculator: ...@@ -203,6 +209,9 @@ class EagleSpeculator:
temperature: torch.Tensor, temperature: torch.Tensor,
# [max_num_reqs] # [max_num_reqs]
seeds: torch.Tensor, seeds: torch.Tensor,
num_tokens_across_dp: torch.Tensor | None = None,
dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
# number of rejected tokens, we maintain the size of eagle's input_ids and # number of rejected tokens, we maintain the size of eagle's input_ids and
...@@ -236,7 +245,7 @@ class EagleSpeculator: ...@@ -236,7 +245,7 @@ class EagleSpeculator:
num_tokens, num_tokens,
attn_metadata, attn_metadata,
slot_mappings, slot_mappings,
num_tokens_across_dp=None, # FIXME num_tokens_across_dp=num_tokens_across_dp,
) )
sample_hidden_states = last_hidden_states[last_token_indices] sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
...@@ -282,48 +291,64 @@ class EagleSpeculator: ...@@ -282,48 +291,64 @@ class EagleSpeculator:
self.max_model_len, self.max_model_len,
self.max_num_reqs, self.max_num_reqs,
) )
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs) if not (dummy_run and skip_attn_for_dummy_run):
cudagraph_mode = self.cudagraph_manager.cudagraph_mode query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL: slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
cudagraph_mode, cudagraph_size = (
self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs)
)
num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = (
get_cudagraph_and_dp_padding(
num_reqs,
cudagraph_size,
cudagraph_mode.value,
self.dp_size,
self.dp_rank,
)
)
cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode)
if cudagraph_mode == CUDAGraphMode.FULL:
# Run full CUDA graph. # Run full CUDA graph.
self.cudagraph_manager.run_fullgraph(cudagraph_size) self.cudagraph_manager.run_fullgraph(num_tokens_padded)
return self.draft_tokens[:num_reqs] return self.draft_tokens[:num_reqs]
# Run eager or piecewise CUDA graph. # Run eager or piecewise CUDA graph.
num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs attn_metadata_updated = None
query_start_loc_cpu = torch.arange( slot_mappings_updated = None
num_reqs + 1, dtype=torch.int32, device="cpu" if not (dummy_run and skip_attn_for_dummy_run):
) query_start_loc_cpu = torch.arange(
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables] num_reqs + 1, dtype=torch.int32, device="cpu"
)
# FIXME(woosuk): This is UNSAFE!! block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
attn_metadata = build_attn_metadata(
attn_groups=self.attn_groups, # FIXME(woosuk): This is UNSAFE!!
num_reqs=num_reqs, attn_metadata_updated = build_attn_metadata(
num_tokens=num_reqs, attn_groups=self.attn_groups,
query_start_loc_gpu=query_start_loc, num_reqs=num_reqs,
query_start_loc_cpu=query_start_loc_cpu, num_tokens=num_reqs,
max_query_len=1, query_start_loc_gpu=query_start_loc,
seq_lens=self.input_buffers.seq_lens[:num_reqs], query_start_loc_cpu=query_start_loc_cpu,
max_seq_len=self.max_model_len, max_query_len=1,
block_tables=block_tables, seq_lens=self.input_buffers.seq_lens[:num_reqs],
slot_mappings=slot_mappings, max_seq_len=self.max_model_len,
kv_cache_config=self.kv_cache_config, block_tables=block_tables,
) slot_mappings=slot_mappings,
slot_mappings_by_layer = build_slot_mappings_by_layer( kv_cache_config=self.kv_cache_config,
slot_mappings, self.kv_cache_config )
) slot_mappings_updated = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
self.generate_draft( self.generate_draft(
num_reqs, num_reqs,
num_tokens_padded, num_tokens_padded,
attn_metadata, attn_metadata_updated,
slot_mappings_by_layer, slot_mappings_updated,
num_tokens_across_dp=None, # FIXME num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_mode, cudagraph_runtime_mode=cudagraph_mode,
) )
return self.draft_tokens[:num_reqs] return self.draft_tokens[:num_reqs]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment