Unverified Commit 2f8b4ce0 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Do not initialize sampler for non-last PP ranks (#36824)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
parent 2ef69456
...@@ -438,17 +438,20 @@ def _post_update_kernel( ...@@ -438,17 +438,20 @@ def _post_update_kernel(
for i in range(num_sampled): for i in range(num_sampled):
token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i) token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
token_ptr = (
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id
)
count = tl.load(token_ptr)
count += 1
tl.store(token_ptr, count)
tl.store( tl.store(
all_token_ids_ptr + req_state_idx * all_token_ids_stride + total_len + i, all_token_ids_ptr + req_state_idx * all_token_ids_stride + total_len + i,
token_id, token_id,
) )
if output_bin_counts_ptr is not None:
token_ptr = (
output_bin_counts_ptr
+ req_state_idx * output_bin_counts_stride
+ token_id
)
count = tl.load(token_ptr)
tl.store(token_ptr, count + 1)
query_start = tl.load(query_start_loc_ptr + req_id) query_start = tl.load(query_start_loc_ptr + req_id)
query_end = tl.load(query_start_loc_ptr + req_id + 1) query_end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = query_end - query_start query_len = query_end - query_start
...@@ -467,7 +470,7 @@ def post_update( ...@@ -467,7 +470,7 @@ def post_update(
# [max_num_reqs] # [max_num_reqs]
last_sampled_tokens: torch.Tensor, last_sampled_tokens: torch.Tensor,
# [max_num_reqs, vocab_size] # [max_num_reqs, vocab_size]
output_bin_counts: torch.Tensor, output_bin_counts: torch.Tensor | None,
# [num_reqs, num_speculative_steps + 1] # [num_reqs, num_speculative_steps + 1]
sampled_tokens: torch.Tensor, sampled_tokens: torch.Tensor,
# [num_reqs] # [num_reqs]
...@@ -487,7 +490,7 @@ def post_update( ...@@ -487,7 +490,7 @@ def post_update(
num_computed_tokens, num_computed_tokens,
last_sampled_tokens, last_sampled_tokens,
output_bin_counts, output_bin_counts,
output_bin_counts.stride(0), output_bin_counts.stride(0) if output_bin_counts is not None else 0,
sampled_tokens, sampled_tokens,
sampled_tokens.stride(0), sampled_tokens.stride(0),
num_sampled, num_sampled,
......
...@@ -183,6 +183,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -183,6 +183,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Draft tokens propagation - for spec-dec + struct outputs. # Draft tokens propagation - for spec-dec + struct outputs.
self.draft_tokens_handler = DraftTokensHandler(self.device) self.draft_tokens_handler = DraftTokensHandler(self.device)
# Pooling models.
self.is_pooling_model = self.model_config.runner_type == "pooling"
self.pooling_runner: PoolingRunner | None = None
# General request states. # General request states.
self.req_states = RequestState( self.req_states = RequestState(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
...@@ -199,6 +203,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -199,6 +203,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_tokens=self.max_num_tokens, max_num_tokens=self.max_num_tokens,
device=self.device, device=self.device,
) )
self.sampler: Sampler | None = None
self.rejection_sampler: RejectionSampler | None = None
self.prompt_logprobs_worker: PromptLogprobsWorker | None = None
self.structured_outputs_worker: StructuredOutputsWorker | None = None
if self.is_last_pp_rank and not self.is_pooling_model:
# Initialize sampling-related workers.
# These components are only set up on the last PP rank and
# for generative (non-pooling) models.
self.sampler = Sampler( self.sampler = Sampler(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
vocab_size=self.vocab_size, vocab_size=self.vocab_size,
...@@ -213,6 +226,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -213,6 +226,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
use_strict_rejection_sampling=use_strict_rejection_sampling, use_strict_rejection_sampling=use_strict_rejection_sampling,
) )
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs) self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
self.structured_outputs_worker = StructuredOutputsWorker(
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
vocab_size=self.vocab_size,
device=self.device,
)
# CUDA graphs. # CUDA graphs.
self.decode_query_len = self.num_speculative_steps + 1 self.decode_query_len = self.num_speculative_steps + 1
...@@ -222,21 +240,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -222,21 +240,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.compilation_config.cudagraph_mode, self.compilation_config.cudagraph_mode,
decode_query_len=self.decode_query_len, decode_query_len=self.decode_query_len,
) )
# Structured outputs worker.
self.structured_outputs_worker = StructuredOutputsWorker(
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
vocab_size=self.vocab_size,
device=self.device,
)
# LoRA-related workers. # LoRA-related workers.
self.lora_state = LoraState(max_num_reqs=self.max_num_reqs) self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
# KV Connector if configured. # KV Connector if configured.
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
# Pooling models.
self.is_pooling_model = self.model_config.runner_type == "pooling"
self.pooling_runner: PoolingRunner | None = None
# For transferring state from execute_model to subsequent sample_tokens call. # For transferring state from execute_model to subsequent sample_tokens call.
self.execute_model_state: ExecuteModelState | None = None self.execute_model_state: ExecuteModelState | None = None
...@@ -248,8 +256,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -248,8 +256,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
tasks: list[SupportedTask] = [] tasks: list[SupportedTask] = []
if self.model_config.runner_type == "generate": if self.model_config.runner_type == "generate":
tasks.extend(self.model_state.get_supported_generation_tasks()) tasks.extend(self.model_state.get_supported_generation_tasks())
if self.pooling_runner is not None: if self.is_pooling_model:
tasks.extend(self.pooling_runner.get_supported_pooling_tasks()) # Do not rely on pooling_runner here, since this information is needed
# on the first PP rank, while pooling_runner is only initialized
# on the last PP rank.
tasks.extend(PoolingRunner.get_supported_tasks(self.model))
return tuple(tasks) return tuple(tasks)
def load_model(self, *args, **kwargs) -> None: def load_model(self, *args, **kwargs) -> None:
...@@ -289,7 +300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -289,7 +300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.model_state = init_model_state( self.model_state = init_model_state(
self.vllm_config, self.model, self.encoder_cache, self.device self.vllm_config, self.model, self.encoder_cache, self.device
) )
if self.is_pooling_model: if self.is_pooling_model and self.is_last_pp_rank:
self.pooling_runner = PoolingRunner(self.model) self.pooling_runner = PoolingRunner(self.model)
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
...@@ -420,6 +431,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -420,6 +431,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# dummy run the eagle speculator's propose to ensure DP/EP sync. # dummy run the eagle speculator's propose to ensure DP/EP sync.
if self.speculator is not None: if self.speculator is not None:
assert self.sampler is not None
self.speculator.propose( self.speculator.propose(
input_batch=input_batch, input_batch=input_batch,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
...@@ -457,10 +469,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -457,10 +469,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): During the initial memory profiling, the sampler may skip # NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible # top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution. # during actual execution.
self.sampler( assert self.sampler is not None
logits, self.sampler(logits, dummy_input_batch)
dummy_input_batch,
)
@torch.inference_mode() @torch.inference_mode()
def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None: def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
...@@ -558,6 +568,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -558,6 +568,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.req_states.remove_request(req_id) self.req_states.remove_request(req_id)
if self.encoder_cache is not None: if self.encoder_cache is not None:
self.encoder_cache.remove_request(req_id) self.encoder_cache.remove_request(req_id)
if self.prompt_logprobs_worker is not None:
self.prompt_logprobs_worker.remove_request(req_id) self.prompt_logprobs_worker.remove_request(req_id)
self.lora_state.remove_request(req_id) self.lora_state.remove_request(req_id)
...@@ -589,18 +600,21 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -589,18 +600,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
self.lora_state.add_request(req_id, req_index, new_req_data.lora_request) self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)
if new_req_data.sampling_params is not None: if self.is_last_pp_rank and new_req_data.sampling_params is not None:
assert self.sampler is not None
self.sampler.add_request( self.sampler.add_request(
req_index, prompt_len, new_req_data.sampling_params req_index, prompt_len, new_req_data.sampling_params
) )
assert self.prompt_logprobs_worker is not None
self.prompt_logprobs_worker.add_request( self.prompt_logprobs_worker.add_request(
req_id, req_index, new_req_data.sampling_params req_id, req_index, new_req_data.sampling_params
) )
if scheduler_output.scheduled_new_reqs: if scheduler_output.scheduled_new_reqs:
self.req_states.apply_staged_writes() self.req_states.apply_staged_writes()
self.sampler.apply_staged_writes()
self.model_state.apply_staged_writes() self.model_state.apply_staged_writes()
if self.sampler is not None:
self.sampler.apply_staged_writes()
def update_requests(self, scheduler_output: SchedulerOutput) -> None: def update_requests(self, scheduler_output: SchedulerOutput) -> None:
# Add new blocks for the existing requests. # Add new blocks for the existing requests.
...@@ -788,6 +802,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -788,6 +802,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None: if grammar_output is not None:
# Apply grammar bitmask to the logits in-place. # Apply grammar bitmask to the logits in-place.
assert self.structured_outputs_worker is not None
self.structured_outputs_worker.apply_grammar_bitmask( self.structured_outputs_worker.apply_grammar_bitmask(
logits, logits,
input_batch, input_batch,
...@@ -797,12 +812,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -797,12 +812,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if input_batch.num_draft_tokens == 0: if input_batch.num_draft_tokens == 0:
# No draft tokens (common case). # No draft tokens (common case).
sampler_output = self.sampler( assert self.sampler is not None
logits, sampler_output = self.sampler(logits, input_batch)
input_batch,
)
else: else:
# Rejection sampling for spec decoding. # Rejection sampling for spec decoding.
assert self.rejection_sampler is not None
sampler_output = self.rejection_sampler( sampler_output = self.rejection_sampler(
logits, logits,
input_batch, input_batch,
...@@ -831,11 +845,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -831,11 +845,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_rejected: torch.Tensor, num_rejected: torch.Tensor,
) -> None: ) -> None:
# Update the number of computed tokens. # Update the number of computed tokens.
if self.is_last_pp_rank:
assert self.sampler is not None
output_bin_counts = self.sampler.penalties_state.output_bin_counts
else:
output_bin_counts = None
post_update( post_update(
input_batch.idx_mapping, input_batch.idx_mapping,
self.req_states.num_computed_tokens.gpu, self.req_states.num_computed_tokens.gpu,
self.req_states.last_sampled_tokens, self.req_states.last_sampled_tokens,
self.sampler.penalties_state.output_bin_counts, output_bin_counts,
sampled_tokens, sampled_tokens,
num_sampled, num_sampled,
num_rejected, num_rejected,
...@@ -1076,6 +1095,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1076,6 +1095,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Broadcast to non-last PP ranks (handles spec decode multi-token). # Broadcast to non-last PP ranks (handles spec decode multi-token).
pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected) pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected)
assert self.prompt_logprobs_worker is not None
prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs( prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
self.model.compute_logits, self.model.compute_logits,
hidden_states, hidden_states,
...@@ -1115,6 +1135,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1115,6 +1135,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
) )
if self.speculator is not None: if self.speculator is not None:
assert self.sampler is not None
draft_tokens = self.speculator.propose( draft_tokens = self.speculator.propose(
input_batch, input_batch,
attn_metadata, attn_metadata,
......
...@@ -19,10 +19,11 @@ class PoolingRunner: ...@@ -19,10 +19,11 @@ class PoolingRunner:
def __init__(self, model: nn.Module): def __init__(self, model: nn.Module):
self.model = cast(VllmModelForPooling, model) self.model = cast(VllmModelForPooling, model)
def get_supported_pooling_tasks(self) -> list[PoolingTask]: @staticmethod
if not is_pooling_model(self.model): def get_supported_tasks(model: nn.Module) -> list[PoolingTask]:
if not is_pooling_model(model):
return [] return []
assert "embed" in self.model.pooler.get_supported_tasks() assert "embed" in model.pooler.get_supported_tasks()
return ["embed"] return ["embed"]
def pool( def pool(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment