Unverified Commit da3222f3 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Implement multi-step Eagle with CUDA graph (#29559)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 43c57925
...@@ -233,10 +233,11 @@ def prepare_inputs_to_capture( ...@@ -233,10 +233,11 @@ def prepare_inputs_to_capture(
query_start_loc.np[num_reqs:] = num_tokens query_start_loc.np[num_reqs:] = num_tokens
query_start_loc.copy_to_gpu() query_start_loc.copy_to_gpu()
seq_lens_np = np.full(num_reqs, max_model_len, dtype=np.int32) seq_lens_np = np.full(num_reqs, max_model_len, dtype=np.int32)
# HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len) # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
# for seq_lens. This leads to a mismatch between seq_lens (GPU) and # rather than max_model_len. This introduces a discrepancy between
# seq_lens_np (CPU), which might cause issues in some attention backends. # seq_lens (on GPU) and seq_lens_np (on CPU), which may cause issues for
input_buffers.seq_lens[:num_reqs] = 1 # certain attention backends.
input_buffers.seq_lens[:num_reqs] = num_tokens
input_buffers.seq_lens[num_reqs:] = 0 input_buffers.seq_lens[num_reqs:] = 0
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables] input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
......
...@@ -140,10 +140,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -140,10 +140,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
# CUDA graphs. # CUDA graphs.
self.cudagraph_manager = CudaGraphManager( self.cudagraph_manager = CudaGraphManager(self.vllm_config, self.device)
vllm_config=self.vllm_config,
device=self.device,
)
def get_supported_tasks(self) -> tuple[str]: def get_supported_tasks(self) -> tuple[str]:
return ("generate",) return ("generate",)
...@@ -203,6 +200,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -203,6 +200,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config, self.vllm_config,
self.device, self.device,
) )
if self.do_spec_decode:
# HACK(woosuk)
self.speculator.set_attn(
self.kv_cache_config,
self.attn_metadata_builders,
self.block_tables,
)
# TODO(woosuk): Support other backends. # TODO(woosuk): Support other backends.
if not all(b.get_name() == "FLASH_ATTN" for b in self.attn_backends.values()): if not all(b.get_name() == "FLASH_ATTN" for b in self.attn_backends.values()):
raise NotImplementedError("Only FLASH_ATTN backend is supported currently.") raise NotImplementedError("Only FLASH_ATTN backend is supported currently.")
...@@ -297,35 +302,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -297,35 +302,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits = self.model.compute_logits(hidden_states) logits = self.model.compute_logits(hidden_states)
self.sampler(logits, sampling_metadata) self.sampler(logits, sampling_metadata)
@torch.inference_mode()
def _dummy_speculator_run(
self,
hidden_states: torch.Tensor,
aux_hidden_states: list[torch.Tensor] | None,
) -> None:
num_tokens = hidden_states.shape[0]
num_reqs = min(num_tokens, self.max_num_reqs)
input_batch = InputBatch.make_dummy(
num_reqs=num_reqs,
num_tokens=num_tokens,
input_buffers=self.input_buffers,
device=self.device,
)
sampling_metadata = SamplingMetadata.make_dummy(
num_reqs=num_reqs,
device=self.device,
)
num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=self.device)
num_rejected = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
self.propose_draft(
input_batch=input_batch,
sampling_metadata=sampling_metadata,
last_hidden_states=hidden_states,
aux_hidden_states=aux_hidden_states,
num_sampled=num_sampled,
num_rejected=num_rejected,
)
@torch.inference_mode() @torch.inference_mode()
def profile_run(self) -> None: def profile_run(self) -> None:
hidden_states, sample_hidden_states = self._dummy_run( hidden_states, sample_hidden_states = self._dummy_run(
...@@ -334,7 +310,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -334,7 +310,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) )
self._dummy_sampler_run(sample_hidden_states) self._dummy_sampler_run(sample_hidden_states)
if self.do_spec_decode: if self.do_spec_decode:
self._dummy_speculator_run(hidden_states, None) num_tokens_across_dp = make_num_tokens_across_dp(
self.dp_size, self.max_num_tokens
)
self.speculator.run_model(
self.max_num_tokens,
attn_metadata=None,
num_tokens_across_dp=num_tokens_across_dp,
)
torch.cuda.synchronize() torch.cuda.synchronize()
del hidden_states, sample_hidden_states del hidden_states, sample_hidden_states
gc.collect() gc.collect()
...@@ -368,6 +351,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -368,6 +351,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata_builders=self.attn_metadata_builders, attn_metadata_builders=self.attn_metadata_builders,
kv_cache_config=self.kv_cache_config, kv_cache_config=self.kv_cache_config,
) )
if self.do_spec_decode:
self.speculator.capture_model()
end_time = time.perf_counter() end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0] end_free_gpu_memory = torch.cuda.mem_get_info()[0]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.sampler import gumbel_sample from vllm.v1.worker.gpu.sampler import gumbel_sample
from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager
from vllm.v1.worker.gpu.states import SamplingMetadata from vllm.v1.worker.gpu.states import SamplingMetadata
logger = init_logger(__name__)
class EagleSpeculator: class EagleSpeculator:
def __init__(self, vllm_config: VllmConfig, device: torch.device): def __init__(self, vllm_config: VllmConfig, device: torch.device):
...@@ -27,14 +39,49 @@ class EagleSpeculator: ...@@ -27,14 +39,49 @@ class EagleSpeculator:
self.scheduler_config = vllm_config.scheduler_config self.scheduler_config = vllm_config.scheduler_config
self.max_num_reqs = self.scheduler_config.max_num_seqs self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_model_len = vllm_config.model_config.max_model_len
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.vocab_size = self.draft_model_config.get_vocab_size()
self.pin_memory = is_pin_memory_available()
self.dtype = vllm_config.model_config.dtype
self.input_ids = torch.zeros( self.input_buffers = InputBuffers(
self.max_num_tokens, dtype=torch.int32, device=device max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
dtype=self.dtype,
device=device,
pin_memory=self.pin_memory,
)
self.hidden_states = torch.zeros(
self.max_num_tokens,
self.hidden_size,
dtype=self.dtype,
device=device,
)
self.temperature = torch.zeros(
self.max_num_reqs,
dtype=torch.float32,
device=device,
) )
self.positions = torch.zeros( self.seeds = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device=device self.max_num_reqs,
dtype=torch.int64,
device=device,
)
self.draft_tokens = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
dtype=torch.int64,
device=device,
) )
self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device)
def load_model(self, target_model: nn.Module) -> None: def load_model(self, target_model: nn.Module) -> None:
from vllm.compilation.backends import set_model_tag from vllm.compilation.backends import set_model_tag
...@@ -49,6 +96,91 @@ class EagleSpeculator: ...@@ -49,6 +96,91 @@ class EagleSpeculator:
del self.model.lm_head del self.model.lm_head
self.model.lm_head = target_model.lm_head self.model.lm_head = target_model.lm_head
def set_attn(
self,
kv_cache_config: KVCacheConfig,
attn_metadata_builders: list[AttentionMetadataBuilder],
block_tables: BlockTables,
) -> None:
self.kv_cache_config = kv_cache_config
self.attn_metadata_builders = attn_metadata_builders
self.block_tables = block_tables
@torch.inference_mode()
def run_model(
self,
num_tokens: int,
attn_metadata: dict[str, Any],
num_tokens_across_dp: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
):
ret_hidden_states = self.model(
input_ids=self.input_buffers.input_ids.gpu[:num_tokens],
positions=self.input_buffers.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
return last_hidden_states, hidden_states
def generate_draft(
self,
num_reqs: int,
attn_metadata: dict[str, Any],
num_tokens_across_dp: torch.Tensor | None,
) -> None:
pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
for step in range(1, self.num_speculative_steps):
# Run the eagle model.
last_hidden_states, hidden_states = self.run_model(
num_reqs, attn_metadata, num_tokens_across_dp
)
logits = self.model.compute_logits(last_hidden_states)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens = gumbel_sample(
logits,
self.temperature[:num_reqs],
self.seeds[:num_reqs],
pos + 1,
apply_temperature=True,
)
self.draft_tokens[:num_reqs, step] = draft_tokens
if step < self.num_speculative_steps - 1:
# Update the inputs for the next step.
update_eagle_inputs(
draft_tokens,
hidden_states,
self.input_buffers,
self.hidden_states,
self.max_model_len,
)
self.block_tables.compute_slot_mappings(query_start_loc, pos)
def capture_model(self) -> None:
if self.num_speculative_steps == 1:
return
logger.info("Capturing model for Eagle speculator...")
self.cudagraph_manager.capture(
self.generate_draft,
self.input_buffers,
self.block_tables,
self.attn_metadata_builders,
self.kv_cache_config,
)
@torch.inference_mode() @torch.inference_mode()
def propose( def propose(
self, self,
...@@ -80,64 +212,110 @@ class EagleSpeculator: ...@@ -80,64 +212,110 @@ class EagleSpeculator:
) )
else: else:
hidden_states = last_hidden_states hidden_states = last_hidden_states
num_tokens = input_batch.num_tokens_after_padding
self.hidden_states[:num_tokens] = hidden_states
# Get the input ids and last token indices for the speculator. # Get the input ids and last token indices for the speculator.
last_token_indices = prepare_eagle_inputs( last_token_indices = prepare_eagle_inputs(
self.input_ids, self.input_buffers,
input_batch, input_batch,
num_sampled, num_sampled,
num_rejected, num_rejected,
last_sampled, last_sampled,
next_prefill_tokens, next_prefill_tokens,
) )
input_ids = self.input_ids[: input_batch.num_tokens_after_padding]
# Prefill: Run the eagle speculator with eager mode. # Prefill: Run the eagle speculator with eager mode.
with set_forward_context( # TODO(woosuk): Support CUDA graph for prefill.
last_hidden_states, hidden_states = self.run_model(
num_tokens,
input_batch.attn_metadata, input_batch.attn_metadata,
self.vllm_config, num_tokens_across_dp=None, # FIXME
num_tokens=input_batch.num_tokens_after_padding,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
):
ret_hidden_states = self.model(
input_ids=input_ids,
positions=input_batch.positions,
hidden_states=hidden_states,
) )
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices] sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
num_reqs = input_batch.num_reqs num_reqs = input_batch.num_reqs
cu_num_logits = input_batch.cu_num_logits[:num_reqs] cu_num_logits = input_batch.cu_num_logits[:num_reqs]
temperature = sampling_metadata.temperature[cu_num_logits]
seed = sampling_metadata.seeds[cu_num_logits]
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
pos = input_batch.positions[last_token_indices] + 1
# NOTE(woosuk): For draft sampling, we only consider the temperature # NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p, # and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance. # for simplicity and performance.
# While this may slightly degrade the acceptance rate, it does not # While this may slightly degrade the acceptance rate, it does not
# affect the output distribution after rejection sampling. # affect the output distribution after rejection sampling.
temperature = self.temperature[:num_reqs]
seeds = self.seeds[:num_reqs]
pos = self.input_buffers.positions[:num_reqs]
# Gather the values and copy them to the pre-allocated buffers.
torch.gather(sampling_metadata.temperature, 0, cu_num_logits, out=temperature)
torch.gather(sampling_metadata.seeds, 0, cu_num_logits, out=seeds)
torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens = gumbel_sample( draft_tokens = gumbel_sample(
logits, temperature, seed, pos, apply_temperature=True logits, temperature, seeds, pos + 1, apply_temperature=True
) )
if self.num_speculative_steps == 1: if self.num_speculative_steps == 1:
# Early exit. # Early exit.
return draft_tokens.view(-1, 1) return draft_tokens.view(-1, 1)
raise NotImplementedError("num_speculative_steps > 1 is not supported yet.")
# Save the draft tokens for the first step.
self.draft_tokens[:num_reqs, 0] = draft_tokens
# Prepare the inputs for the decode steps.
prepare_eagle_decode(
draft_tokens,
hidden_states,
last_token_indices,
input_batch.seq_lens,
num_rejected,
self.input_buffers,
self.hidden_states,
self.max_model_len,
self.max_num_reqs,
)
query_start_loc = self.input_buffers.query_start_loc
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, pos
)
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
if cudagraph_size is not None:
# Run CUDA graph.
self.cudagraph_manager.run(cudagraph_size)
return self.draft_tokens[:num_reqs]
# Run eager mode.
query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1)
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
# HACK(woosuk)
seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
# FIXME(woosuk): This is UNSAFE!!
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
seq_lens_np=seq_lens_np,
num_computed_tokens_cpu=None, # FIXME
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
self.generate_draft(num_reqs, attn_metadata, num_tokens_across_dp=None) # FIXME
return self.draft_tokens[:num_reqs]
@triton.jit @triton.jit
def _prepare_eagle_inputs_kernel( def _prepare_eagle_inputs_kernel(
last_token_indices_ptr, last_token_indices_ptr,
eagle_input_ids_ptr, eagle_input_ids_ptr,
eagle_positions_ptr,
target_input_ids_ptr, target_input_ids_ptr,
target_positions_ptr,
idx_mapping_ptr, idx_mapping_ptr,
last_sampled_ptr, last_sampled_ptr,
next_prefill_tokens_ptr, next_prefill_tokens_ptr,
...@@ -175,9 +353,16 @@ def _prepare_eagle_inputs_kernel( ...@@ -175,9 +353,16 @@ def _prepare_eagle_inputs_kernel(
tl.store(last_token_indices_ptr + batch_idx, last_token_index) tl.store(last_token_indices_ptr + batch_idx, last_token_index)
tl.store(eagle_input_ids_ptr + last_token_index, next_token) tl.store(eagle_input_ids_ptr + last_token_index, next_token)
# Copy positions.
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
target_pos = tl.load(target_positions_ptr + query_start + block, mask=mask)
tl.store(eagle_positions_ptr + query_start + block, target_pos, mask=mask)
def prepare_eagle_inputs( def prepare_eagle_inputs(
eagle_input_ids: torch.Tensor, input_buffers: InputBuffers,
input_batch: InputBatch, input_batch: InputBatch,
# [num_reqs] # [num_reqs]
num_sampled: torch.Tensor, num_sampled: torch.Tensor,
...@@ -192,12 +377,14 @@ def prepare_eagle_inputs( ...@@ -192,12 +377,14 @@ def prepare_eagle_inputs(
last_token_indices = torch.empty( last_token_indices = torch.empty(
num_reqs, num_reqs,
dtype=torch.int64, dtype=torch.int64,
device=eagle_input_ids.device, device=num_sampled.device,
) )
_prepare_eagle_inputs_kernel[(num_reqs,)]( _prepare_eagle_inputs_kernel[(num_reqs,)](
last_token_indices, last_token_indices,
eagle_input_ids, input_buffers.input_ids.gpu,
input_buffers.positions,
input_batch.input_ids, input_batch.input_ids,
input_batch.positions,
input_batch.idx_mapping, input_batch.idx_mapping,
last_sampled, last_sampled,
next_prefill_tokens, next_prefill_tokens,
...@@ -207,3 +394,174 @@ def prepare_eagle_inputs( ...@@ -207,3 +394,174 @@ def prepare_eagle_inputs(
BLOCK_SIZE=1024, BLOCK_SIZE=1024,
) )
return last_token_indices return last_token_indices
@triton.jit
def _prepare_eagle_docode_kernel(
draft_tokens_ptr,
output_hidden_states_ptr,
output_hidden_states_stride,
last_token_indices_ptr,
target_seq_lens_ptr,
num_rejected_ptr,
input_ids_ptr,
positions_ptr,
input_hidden_states_ptr,
input_hidden_states_stride,
query_start_loc_ptr,
seq_lens_ptr,
hidden_size,
max_model_len,
max_num_reqs,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
num_reqs = tl.num_programs(0) - 1
if req_idx == num_reqs:
# Compute query_start_loc. Pad it with the last query_start_loc
# for CUDA graphs.
for i in range(0, max_num_reqs + 1, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
q = tl.where(block < num_reqs, block, num_reqs)
mask = block < max_num_reqs + 1
tl.store(query_start_loc_ptr + block, q, mask=mask)
# Pad seq_lens for CUDA graphs.
for i in range(req_idx, max_num_reqs, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < max_num_reqs
tl.store(seq_lens_ptr + block, 0, mask=mask)
return
# draft token -> input id.
draft_token = tl.load(draft_tokens_ptr + req_idx)
tl.store(input_ids_ptr + req_idx, draft_token)
# output hidden states -> input hidden states.
src_idx = tl.load(last_token_indices_ptr + req_idx)
for i in range(0, hidden_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < hidden_size
output_hidden_states = tl.load(
output_hidden_states_ptr + src_idx * output_hidden_states_stride + block,
mask=mask,
)
tl.store(
input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
output_hidden_states,
mask=mask,
)
# Compute position and seq_lens.
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
# if they reach the max model length.
position = tl.load(positions_ptr + req_idx)
position = tl.minimum(position + 1, max_model_len - 1)
tl.store(positions_ptr + req_idx, position)
target_seq_len = tl.load(target_seq_lens_ptr + req_idx)
num_rejected = tl.load(num_rejected_ptr + req_idx)
seq_len = target_seq_len - num_rejected
seq_len = tl.minimum(seq_len + 1, max_model_len)
tl.store(seq_lens_ptr + req_idx, seq_len)
def prepare_eagle_decode(
draft_tokens: torch.Tensor,
output_hidden_states: torch.Tensor,
last_token_indices: torch.Tensor,
target_seq_lens: torch.Tensor,
num_rejected: torch.Tensor,
input_buffers: InputBuffers,
input_hidden_states: torch.Tensor,
max_model_len: int,
max_num_reqs: int,
):
num_reqs = draft_tokens.shape[0]
hidden_size = output_hidden_states.shape[-1]
_prepare_eagle_docode_kernel[(num_reqs + 1,)](
draft_tokens,
output_hidden_states,
output_hidden_states.stride(0),
last_token_indices,
target_seq_lens,
num_rejected,
input_buffers.input_ids.gpu,
input_buffers.positions,
input_hidden_states,
input_hidden_states.stride(0),
input_buffers.query_start_loc.gpu,
input_buffers.seq_lens,
hidden_size,
max_model_len,
max_num_reqs,
BLOCK_SIZE=1024,
)
@triton.jit
def _update_eagle_inputs_kernel(
input_ids_ptr,
positions_ptr,
input_hidden_states_ptr,
input_hidden_states_stride,
seq_lens_ptr,
max_model_len,
draft_tokens_ptr,
output_hidden_states_ptr,
output_hidden_states_stride,
hidden_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
# Draft token -> Input ID.
draft_token = tl.load(draft_tokens_ptr + req_idx)
tl.store(input_ids_ptr + req_idx, draft_token)
# Output hidden states -> Input hidden states.
for i in range(0, hidden_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < hidden_size
output_hidden_states = tl.load(
output_hidden_states_ptr + req_idx * output_hidden_states_stride + block,
mask=mask,
)
tl.store(
input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
output_hidden_states,
mask=mask,
)
# Increment position and seq_lens.
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
# if they reach the max model length.
position = tl.load(positions_ptr + req_idx)
position = tl.minimum(position + 1, max_model_len - 1)
tl.store(positions_ptr + req_idx, position)
seq_len = tl.load(seq_lens_ptr + req_idx)
seq_len = tl.minimum(seq_len + 1, max_model_len)
tl.store(seq_lens_ptr + req_idx, seq_len)
def update_eagle_inputs(
draft_tokens: torch.Tensor,
output_hidden_states: torch.Tensor,
input_buffers: InputBuffers,
hidden_states: torch.Tensor,
max_model_len: int,
):
num_reqs, hidden_size = output_hidden_states.shape
_update_eagle_inputs_kernel[(num_reqs,)](
input_buffers.input_ids.gpu,
input_buffers.positions,
hidden_states,
hidden_states.stride(0),
input_buffers.seq_lens,
max_model_len,
draft_tokens,
output_hidden_states,
output_hidden_states.stride(0),
hidden_size,
BLOCK_SIZE=1024,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import (
capture_graphs,
get_cudagraph_sizes,
prepare_inputs_to_capture,
)
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
class EagleCudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
if self.compilation_config.cudagraph_mode is None:
self.cudagraph_mode = CUDAGraphMode.NONE
else:
self.cudagraph_mode = self.compilation_config.cudagraph_mode
if self.cudagraph_mode == CUDAGraphMode.FULL:
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
self.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
self.cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
)
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = torch.cuda.graph_pool_handle()
def get_cudagraph_size(self, num_tokens: int) -> int | None:
return self.cudagraph_sizes.get(num_tokens)
def capture_graph(
self,
num_tokens: int,
generate_fn: Callable,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
num_reqs = min(num_tokens, self.max_num_reqs)
attn_metadata = prepare_inputs_to_capture(
num_reqs,
num_tokens,
input_buffers,
block_tables,
attn_metadata_builders,
self.max_model_len,
kv_cache_config,
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up.
generate_fn(num_tokens, attn_metadata, num_tokens_across_dp)
# Capture the graph.
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, self.pool):
generate_fn(num_tokens, attn_metadata, num_tokens_across_dp)
self.graphs[num_tokens] = graph
@torch.inference_mode()
def capture(
self,
generate_fn: Callable,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
capture_graphs(
self.cudagraph_sizes,
self.device,
self.capture_graph,
generate_fn=generate_fn,
input_buffers=input_buffers,
block_tables=block_tables,
attn_metadata_builders=attn_metadata_builders,
kv_cache_config=kv_cache_config,
)
def run(self, num_tokens: int) -> None:
assert num_tokens in self.graphs
self.graphs[num_tokens].replay()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment