Unverified Commit 570d3343 authored by Xiaoze Fan, committed by GitHub

[Feature] Layer-wise Prefill (#7634)


Signed-off-by: jason-fxz <jason341132@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent d9eb5efc
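
Based on the hunks below, layer-wise (split) prefill executes one prefill forward pass as a series of partial forwards, each covering a contiguous slice of decoder layers, with the intermediate activations parked on the ForwardBatch between calls; per the comment on the new forward mode, this is intended for PD multiplexing. A minimal, hypothetical driver loop (the names batch, model_runner, and build_forward_batch, and the ForwardBatch construction, are assumptions, not code from this commit):

# Hypothetical sketch of driving a split prefill, under the assumptions above.
batch.prepare_for_split_prefill()            # puts the batch in ForwardMode.SPLIT_PREFILL
forward_batch = build_forward_batch(batch, model_runner)  # ForwardBatch construction elided

num_layers = model_runner.model_config.num_hidden_layers
logits_output = None
while forward_batch.split_index < num_layers:
    # Each call runs up to split_forward_count more decoder layers and advances
    # forward_batch.split_index; other work could be interleaved between calls.
    logits_output, _ = model_runner.forward(forward_batch, split_forward_count=4)
# Only the call that finishes the last layer returns a LogitsProcessorOutput;
# the earlier calls return None in its place.
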
@@ -1328,6 +1328,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.model_config.vocab_size,
)
def prepare_for_split_prefill(self):
self.prepare_for_extend()
# For split prefill, we need to set the forward mode to SPLIT_PREFILL
self.forward_mode = ForwardMode.SPLIT_PREFILL
def mix_with_running(self, running_batch: "ScheduleBatch"):
self.forward_mode = ForwardMode.MIXED
running_bs = running_batch.batch_size()
......
@@ -68,6 +68,8 @@ class ForwardMode(IntEnum):
MIXED = auto()
# No sequence to forward. For data parallel attention, some workers will be IDLE if no sequences are allocated.
IDLE = auto()
# Split Prefill for PD multiplexing
SPLIT_PREFILL = auto()
# Used in speculative decoding: verify a batch in the target model.
TARGET_VERIFY = auto()
@@ -95,6 +97,9 @@ class ForwardMode(IntEnum):
def is_mixed(self):
return self == ForwardMode.MIXED
def is_split_prefill(self):
return self == ForwardMode.SPLIT_PREFILL
def is_idle(self):
return self == ForwardMode.IDLE
@@ -194,6 +199,14 @@ class ForwardBatch:
extend_logprob_start_lens_cpu: Optional[List[int]] = None
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
# For split prefill
# intermediate values for split prefill
hidden_states: torch.Tensor = None
residual: torch.Tensor = None
model_specific_states: Dict[str, any] = None
# current split index of layer
split_index: int = 0
# For MLA chunked prefix cache used in chunked prefill
# Tell attention backend whether the kv cache needs to be attended in current pass
attn_attend_prefix_cache: Optional[bool] = None
......
@@ -1513,11 +1513,34 @@ class ModelRunner:
**kwargs,
)
def forward_split_prefill(
self,
forward_batch: ForwardBatch,
reinit_attn_backend: bool = False,
forward_count: int = 1,
) -> LogitsProcessorOutput:
if forward_batch.split_index == 0 or reinit_attn_backend:
self.attn_backend.init_forward_metadata(forward_batch)
next_split_index = min(
forward_batch.split_index + forward_count,
self.model_config.num_hidden_layers,
)
ret = self.model.forward_split_prefill(
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
(forward_batch.split_index, next_split_index),
)
forward_batch.split_index = next_split_index
return ret
def forward(
self,
forward_batch: ForwardBatch,
skip_attn_backend_init: bool = False,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
reinit_attn_backend: bool = False,
split_forward_count: int = 1,
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
self.forward_pass_id += 1
@@ -1526,7 +1549,11 @@ class ModelRunner:
forward_batch,
):
output = self._forward_raw(
forward_batch, skip_attn_backend_init, pp_proxy_tensors
forward_batch,
skip_attn_backend_init,
pp_proxy_tensors,
reinit_attn_backend,
split_forward_count,
)
if self.eplb_manager is not None:
@@ -1539,6 +1566,8 @@ class ModelRunner:
forward_batch: ForwardBatch,
skip_attn_backend_init: bool,
pp_proxy_tensors: Optional[PPProxyTensors],
reinit_attn_backend: bool = False,
split_forward_count: int = 1,
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
can_run_cuda_graph = bool(
forward_batch.forward_mode.is_cuda_graph()
@@ -1559,6 +1588,12 @@ class ModelRunner:
skip_attn_backend_init=skip_attn_backend_init,
pp_proxy_tensors=pp_proxy_tensors,
)
elif forward_batch.forward_mode.is_split_prefill():
ret = self.forward_split_prefill(
forward_batch,
reinit_attn_backend=reinit_attn_backend,
forward_count=split_forward_count,
)
elif forward_batch.forward_mode.is_idle():
ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
else:
......
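
The slice arithmetic in forward_split_prefill above clamps each chunk at num_hidden_layers, so the final chunk can be shorter than forward_count; a small self-contained illustration with assumed numbers:

# Worked example of the chunking above; the layer count and chunk size are
# assumed numbers, not taken from this commit.
num_hidden_layers, forward_count = 36, 8
split_index, calls = 0, 0
while split_index < num_hidden_layers:
    split_index = min(split_index + forward_count, num_hidden_layers)
    calls += 1
print(calls)  # 5 partial forwards: split_index goes 0 -> 8 -> 16 -> 24 -> 32 -> 36

Only the call that reaches the last layer runs the final norm and logits processor in the per-model forward_split_prefill implementations below.
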
@@ -318,6 +318,54 @@ class GemmaForCausalLM(nn.Module):
input_ids, hidden_states, self.model.embed_tokens, forward_batch
)
@torch.no_grad()
def forward_split_prefill(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
split_interval: Tuple[int, int], # [start, end) 0-based
input_embeds: torch.Tensor = None,
):
start, end = split_interval
# embed
if start == 0:
if input_embeds is None:
forward_batch.hidden_states = self.model.embed_tokens(input_ids)
else:
forward_batch.hidden_states = input_embeds
# Normalize the embedding by sqrt(hidden_size)
forward_batch.hidden_states *= self.model.config.hidden_size**0.5
# decoder layer
for i in range(start, end):
layer = self.model.layers[i]
forward_batch.hidden_states, forward_batch.residual = layer(
positions,
forward_batch.hidden_states,
forward_batch,
forward_batch.residual,
)
if end == self.model.config.num_hidden_layers:
# norm
forward_batch.hidden_states, _ = self.model.norm(
forward_batch.hidden_states, forward_batch.residual
)
# logits process
result = self.logits_processor(
input_ids,
forward_batch.hidden_states,
self.model.embed_tokens,
forward_batch,
)
else:
result = None
return result
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
......
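
Every per-model forward_split_prefill added in this commit (Gemma above; Gemma2, Gemma3, Llama, Qwen, Qwen2, Qwen2-MoE, Qwen3, and Qwen3-MoE below) follows the same three-phase contract. A condensed, hypothetical restatement, with the embed/layer/finalize callables as illustrative stand-ins rather than code from the diff:

from typing import Callable, Optional, Tuple

def split_prefill_step(
    forward_batch,                    # carries hidden_states / residual between calls
    split_interval: Tuple[int, int],  # [start, end), 0-based layer indices
    num_layers: int,
    embed_fn: Callable,               # phase 1: token embedding (+ model-specific setup)
    layer_fn: Callable,               # phase 2: one decoder layer
    finalize_fn: Callable,            # phase 3: final norm + logits processing
) -> Optional[object]:
    start, end = split_interval
    if start == 0:
        # First chunk only: build the initial hidden states
        # (forward_batch.residual starts out as None on a fresh ForwardBatch).
        forward_batch.hidden_states = embed_fn()
    for i in range(start, end):
        # Run just the layers belonging to this chunk.
        forward_batch.hidden_states, forward_batch.residual = layer_fn(
            i, forward_batch.hidden_states, forward_batch.residual
        )
    if end == num_layers:
        # Last chunk only: final norm and logits.
        return finalize_fn(forward_batch.hidden_states, forward_batch.residual)
    return None  # intermediate chunk: state stays on forward_batch for the next call

The only notable deviations below are Gemma3, which stashes positions and its global/local rotary embeddings in forward_batch.model_specific_states instead of a residual, and the original Qwen model, whose layers return only hidden states.
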
@@ -381,6 +381,57 @@ class Gemma2ForCausalLM(nn.Module):
input_ids, hidden_states, self.model.embed_tokens, forward_batch
)
@torch.no_grad()
def forward_split_prefill(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
split_interval: Tuple[int, int], # [start, end) 0-based
input_embeds: torch.Tensor = None,
):
start, end = split_interval
# embed
if start == 0:
if input_embeds is None:
forward_batch.hidden_states = self.model.embed_tokens(input_ids)
else:
forward_batch.hidden_states = input_embeds
# Normalize
normalizer = torch.tensor(
self.model.config.hidden_size**0.5, dtype=torch.float16
)
forward_batch.hidden_states *= normalizer
# decoder layer
for i in range(start, end):
layer = self.model.layers[i]
forward_batch.hidden_states, forward_batch.residual = layer(
positions,
forward_batch.hidden_states,
forward_batch,
forward_batch.residual,
)
if end == self.model.config.num_hidden_layers:
# norm
forward_batch.hidden_states, _ = self.model.norm(
forward_batch.hidden_states, forward_batch.residual
)
# logits process
result = self.logits_processor(
input_ids,
forward_batch.hidden_states,
self.model.embed_tokens,
forward_batch,
)
else:
result = None
return result
def get_hidden_dim(self, module_name):
# return input_dim, output_dim
if module_name in ["q_proj", "qkv_proj"]:
......
@@ -647,6 +647,69 @@ class Gemma3ForCausalLM(PreTrainedModel):
input_ids, hidden_states, self.model.embed_tokens, forward_batch
)
@torch.no_grad()
def forward_split_prefill(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
split_interval: Tuple[int, int], # [start, end) 0-based
input_embeds: torch.Tensor = None,
):
start, end = split_interval
# embed
if start == 0:
if input_embeds is None:
hidden_states = self.model.embed_tokens(input_ids)
else:
hidden_states = input_embeds
if positions.dim() == 1:
positions = einops.rearrange(positions, "s -> 1 s")
position_embeddings_global = self.model.rotary_emb(hidden_states, positions)
position_embeddings_local = self.model.rotary_emb_local(
hidden_states, positions
)
forward_batch.hidden_states = hidden_states
forward_batch.model_specific_states = {
"positions": positions,
"position_embeddings_global": position_embeddings_global,
"position_embeddings_local": position_embeddings_local,
}
# decoder layer
for i in range(start, end):
layer = self.model.layers[i]
layer_output = layer(
positions=forward_batch.model_specific_states["positions"],
position_embeddings_global=forward_batch.model_specific_states[
"position_embeddings_global"
],
position_embeddings_local=forward_batch.model_specific_states[
"position_embeddings_local"
],
hidden_states=forward_batch.hidden_states,
forward_batch=forward_batch,
)
forward_batch.hidden_states = layer_output[0]
if end == self.model.config.num_hidden_layers:
# norm
forward_batch.hidden_states = self.model.norm(forward_batch.hidden_states)
# logits process
result = self.logits_processor(
input_ids,
forward_batch.hidden_states,
self.model.embed_tokens,
forward_batch,
)
else:
result = None
return result
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
......
@@ -480,6 +480,47 @@ class LlamaForCausalLM(nn.Module):
else:
return hidden_states
@torch.no_grad()
def forward_split_prefill(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
split_interval: Tuple[int, int], # [start, end) 0-based
input_embeds: torch.Tensor = None,
) -> Optional[LogitsProcessorOutput]:
start, end = split_interval
# embed
if start == 0:
if input_embeds is None:
forward_batch.hidden_states = self.model.embed_tokens(input_ids)
else:
forward_batch.hidden_states = input_embeds
# decoder layer
for i in range(start, end):
layer = self.model.layers[i]
forward_batch.hidden_states, forward_batch.residual = layer(
positions,
forward_batch.hidden_states,
forward_batch,
forward_batch.residual,
)
if end == self.model.config.num_hidden_layers:
# norm
hidden_states, _ = self.model.norm(
forward_batch.hidden_states, forward_batch.residual
)
forward_batch.hidden_states = hidden_states
# logits process
result = self.logits_processor(
input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
)
else:
result = None
return result
@property
def start_layer(self):
return self.model.start_layer
......
@@ -15,6 +15,7 @@
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
import time
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
@@ -286,6 +287,42 @@ class QWenLMHeadModel(nn.Module):
input_ids, hidden_states, self.lm_head, forward_batch
)
@torch.no_grad()
def forward_split_prefill(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
split_interval: Tuple[int, int], # [start, end) 0-based
):
start, end = split_interval
# embed
if start == 0:
forward_batch.hidden_states = self.transformer.wte(input_ids)
# decoder layer
for i in range(start, end):
layer = self.transformer.h[i]
forward_batch.hidden_states = layer(
positions,
forward_batch.hidden_states,
forward_batch,
)
if end == self.transformer.config.num_hidden_layers:
# norm
forward_batch.hidden_states = self.transformer.ln_f(
forward_batch.hidden_states
)
# logits process
result = self.logits_processor(
input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
)
else:
result = None
return result
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
......
@@ -481,6 +481,47 @@ class Qwen2ForCausalLM(nn.Module):
else:
return hidden_states
@torch.no_grad()
def forward_split_prefill(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
split_interval: Tuple[int, int], # [start, end) 0-based
input_embeds: torch.Tensor = None,
):
start, end = split_interval
# embed
if start == 0:
if input_embeds is None:
forward_batch.hidden_states = self.model.embed_tokens(input_ids)
else:
forward_batch.hidden_states = input_embeds
# decoder layer
for i in range(start, end):
layer = self.model.layers[i]
forward_batch.hidden_states, forward_batch.residual = layer(
positions,
forward_batch.hidden_states,
forward_batch,
forward_batch.residual,
)
if end == self.model.config.num_hidden_layers:
# norm
hidden_states, _ = self.model.norm(
forward_batch.hidden_states, forward_batch.residual
)
forward_batch.hidden_states = hidden_states
# logits process
result = self.logits_processor(
input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
)
else:
result = None
return result
@property
def start_layer(self):
return self.model.start_layer
......
@@ -406,6 +406,7 @@ class Qwen2MoeModel(nn.Module):
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.pp_group = get_pp_group()
@@ -554,6 +555,49 @@ class Qwen2MoeForCausalLM(nn.Module):
else:
return hidden_states
@torch.no_grad()
def forward_split_prefill(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
split_interval: Tuple[int, int], # [start, end) 0-based
input_embeds: torch.Tensor = None,
):
start, end = split_interval
# embed
if start == 0:
if input_embeds is None:
forward_batch.hidden_states = self.model.embed_tokens(input_ids)
else:
forward_batch.hidden_states = input_embeds
# decoder layer
for i in range(start, end):
with get_global_expert_distribution_recorder().with_current_layer(i):
layer = self.model.layers[i]
forward_batch.hidden_states, forward_batch.residual = layer(
positions,
forward_batch.hidden_states,
forward_batch,
forward_batch.residual,
)
if end == self.model.config.num_hidden_layers:
# norm
hidden_states, _ = self.model.norm(
forward_batch.hidden_states, forward_batch.residual
)
forward_batch.hidden_states = hidden_states
# logits process
result = self.logits_processor(
input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
)
else:
result = None
return result
@property
def start_layer(self):
return self.model.start_layer
......
# Adapted from qwen2.py
import logging
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple
@@ -367,6 +366,47 @@ class Qwen3ForCausalLM(nn.Module):
else:
return hidden_states
@torch.no_grad()
def forward_split_prefill(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
split_interval: Tuple[int, int], # [start, end) 0-based
input_embeds: torch.Tensor = None,
):
start, end = split_interval
# embed
if start == 0:
if input_embeds is None:
forward_batch.hidden_states = self.model.embed_tokens(input_ids)
else:
forward_batch.hidden_states = input_embeds
# decoder layer
for i in range(start, end):
layer = self.model.layers[i]
forward_batch.hidden_states, forward_batch.residual = layer(
positions,
forward_batch.hidden_states,
forward_batch,
forward_batch.residual,
)
if end == self.model.config.num_hidden_layers:
# norm
hidden_states, _ = self.model.norm(
forward_batch.hidden_states, forward_batch.residual
)
forward_batch.hidden_states = hidden_states
# logits process
result = self.logits_processor(
input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
)
else:
result = None
return result
@property
def start_layer(self):
return self.model.start_layer
......
@@ -745,6 +745,49 @@ class Qwen3MoeForCausalLM(nn.Module):
else:
return hidden_states
@torch.no_grad()
def forward_split_prefill(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
split_interval: Tuple[int, int], # [start, end) 0-based
input_embeds: torch.Tensor = None,
):
start, end = split_interval
# embed
if start == 0:
if input_embeds is None:
forward_batch.hidden_states = self.model.embed_tokens(input_ids)
else:
forward_batch.hidden_states = input_embeds
# decoder layer
for i in range(start, end):
with get_global_expert_distribution_recorder().with_current_layer(i):
layer = self.model.layers[i]
forward_batch.hidden_states, forward_batch.residual = layer(
positions,
forward_batch.hidden_states,
forward_batch,
forward_batch.residual,
)
if end == self.model.config.num_hidden_layers:
# norm
hidden_states, _ = self.model.norm(
forward_batch.hidden_states, forward_batch.residual
)
forward_batch.hidden_states = hidden_states
# logits process
result = self.logits_processor(
input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
)
else:
result = None
return result
@property
def start_layer(self):
return self.model.start_layer
......
@@ -500,6 +500,7 @@ class TboForwardBatchPreparer:
"capture_hidden_mode",
"padded_static_len",
"mrope_positions", # only used by qwen2-vl, thus not care
"split_index", # for split prefill
]:
output_dict[key] = getattr(batch, key)
if not batch.forward_mode.is_target_verify():
......