Unverified commit 570d3343, authored by Xiaoze Fan, committed by GitHub

[Feature] Layer-wise Prefill (#7634)


Signed-off-by: jason-fxz <jason341132@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent d9eb5efc
@@ -1328,6 +1328,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            self.model_config.vocab_size,
        )

    def prepare_for_split_prefill(self):
        self.prepare_for_extend()
        # For split prefill, we need to set the forward mode to SPLIT_PREFILL
        self.forward_mode = ForwardMode.SPLIT_PREFILL

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()
...
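For orientation, a minimal usage sketch that is not part of this diff: it assumes `batch` is an already-built ScheduleBatch holding prefill requests, and only shows how the new method switches the batch into the new mode.

# Hedged sketch; `batch` is assumed to be a prefill ScheduleBatch.
batch.prepare_for_split_prefill()
# The batch now reports ForwardMode.SPLIT_PREFILL, so ModelRunner.forward()
# (further down in this commit) dispatches to forward_split_prefill().
assert batch.forward_mode == ForwardMode.SPLIT_PREFILL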
@@ -68,6 +68,8 @@ class ForwardMode(IntEnum):
    MIXED = auto()
    # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequences are allocated.
    IDLE = auto()
    # Split Prefill for PD multiplexing
    SPLIT_PREFILL = auto()
    # Used in speculative decoding: verify a batch in the target model.
    TARGET_VERIFY = auto()
@@ -95,6 +97,9 @@ class ForwardMode(IntEnum):
    def is_mixed(self):
        return self == ForwardMode.MIXED

    def is_split_prefill(self):
        return self == ForwardMode.SPLIT_PREFILL

    def is_idle(self):
        return self == ForwardMode.IDLE
@@ -194,6 +199,14 @@ class ForwardBatch:
    extend_logprob_start_lens_cpu: Optional[List[int]] = None
    extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None

    # Intermediate values carried across split-prefill passes
    hidden_states: Optional[torch.Tensor] = None
    residual: Optional[torch.Tensor] = None
    model_specific_states: Optional[Dict[str, Any]] = None
    # Current layer split index (the first layer the next pass will run)
    split_index: int = 0

    # For MLA chunked prefix cache used in chunked prefill
    # Tell attention backend whether the kv cache needs to be attended in current pass
    attn_attend_prefix_cache: Optional[bool] = None
...
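To make the new ForwardBatch fields concrete, here is a toy, self-contained illustration (not SGLang code) of the state a batch carries between partial passes; the names mirror the fields above and the shapes are only indicative.

import torch

num_tokens, hidden_size = 128, 4096  # illustrative sizes
hidden_states = torch.zeros(num_tokens, hidden_size)  # activations after the last executed layer
residual = torch.zeros(num_tokens, hidden_size)       # residual stream paired with hidden_states
model_specific_states = {}  # model extras, e.g. Gemma3 caches its rotary embeddings here
split_index = 0             # index of the first layer the next partial pass will run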
@@ -1513,11 +1513,34 @@ class ModelRunner:
            **kwargs,
        )

    def forward_split_prefill(
        self,
        forward_batch: ForwardBatch,
        reinit_attn_backend: bool = False,
        forward_count: int = 1,
    ) -> LogitsProcessorOutput:
        if forward_batch.split_index == 0 or reinit_attn_backend:
            self.attn_backend.init_forward_metadata(forward_batch)
        next_split_index = min(
            forward_batch.split_index + forward_count,
            self.model_config.num_hidden_layers,
        )
        ret = self.model.forward_split_prefill(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            (forward_batch.split_index, next_split_index),
        )
        forward_batch.split_index = next_split_index
        return ret

    def forward(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        reinit_attn_backend: bool = False,
        split_forward_count: int = 1,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        self.forward_pass_id += 1
@@ -1526,7 +1549,11 @@ class ModelRunner:
            forward_batch,
        ):
            output = self._forward_raw(
                forward_batch,
                skip_attn_backend_init,
                pp_proxy_tensors,
                reinit_attn_backend,
                split_forward_count,
            )

        if self.eplb_manager is not None:
@@ -1539,6 +1566,8 @@ class ModelRunner:
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool,
        pp_proxy_tensors: Optional[PPProxyTensors],
        reinit_attn_backend: bool = False,
        split_forward_count: int = 1,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        can_run_cuda_graph = bool(
            forward_batch.forward_mode.is_cuda_graph()
@@ -1559,6 +1588,12 @@ class ModelRunner:
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_split_prefill():
            ret = self.forward_split_prefill(
                forward_batch,
                reinit_attn_backend=reinit_attn_backend,
                forward_count=split_forward_count,
            )
        elif forward_batch.forward_mode.is_idle():
            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        else:
...
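A hedged sketch, not part of this PR, of how a scheduler loop could drive a layer-wise prefill to completion through the new arguments; `model_runner` and `forward_batch` are assumed to be set up as usual, with the batch already in SPLIT_PREFILL mode, and the chunk size of 4 layers is illustrative.

logits_output = None
while forward_batch.split_index < model_runner.model_config.num_hidden_layers:
    # Each call runs `split_forward_count` more decoder layers and advances
    # forward_batch.split_index; only the call that finishes the last layer
    # yields a LogitsProcessorOutput, earlier calls yield None in its place.
    logits_output, _ = model_runner.forward(forward_batch, split_forward_count=4)
    # A PD-multiplexing scheduler could interleave decode batches here.
assert logits_output is not None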
@@ -318,6 +318,54 @@ class GemmaForCausalLM(nn.Module):
            input_ids, hidden_states, self.model.embed_tokens, forward_batch
        )

    @torch.no_grad()
    def forward_split_prefill(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        split_interval: Tuple[int, int],  # [start, end) 0-based
        input_embeds: torch.Tensor = None,
    ):
        start, end = split_interval
        # embed
        if start == 0:
            if input_embeds is None:
                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
            else:
                forward_batch.hidden_states = input_embeds
            # Normalize the embedding by sqrt(hidden_size)
            forward_batch.hidden_states *= self.model.config.hidden_size**0.5

        # decoder layer
        for i in range(start, end):
            layer = self.model.layers[i]
            forward_batch.hidden_states, forward_batch.residual = layer(
                positions,
                forward_batch.hidden_states,
                forward_batch,
                forward_batch.residual,
            )

        if end == self.model.config.num_hidden_layers:
            # norm
            forward_batch.hidden_states, _ = self.model.norm(
                forward_batch.hidden_states, forward_batch.residual
            )
            # logits process
            result = self.logits_processor(
                input_ids,
                forward_batch.hidden_states,
                self.model.embed_tokens,
                forward_batch,
            )
        else:
            result = None

        return result

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
...
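Every model added in this commit implements the same interval contract, illustrated below with a hypothetical two-call example; `model`, `input_ids`, `positions`, and `forward_batch` are assumed to be prepared, attention-backend metadata is assumed to be initialized already (ModelRunner does this when split_index == 0), and the 28-layer count is made up.

# First interval: layers [0, 14) run, state is left on forward_batch, no logits yet.
mid = model.forward_split_prefill(input_ids, positions, forward_batch, (0, 14))
assert mid is None
# Final interval reaches num_hidden_layers, so norm + logits processing run.
out = model.forward_split_prefill(input_ids, positions, forward_batch, (14, 28))
assert out is not None  # LogitsProcessorOutput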
@@ -381,6 +381,57 @@ class Gemma2ForCausalLM(nn.Module):
            input_ids, hidden_states, self.model.embed_tokens, forward_batch
        )

    @torch.no_grad()
    def forward_split_prefill(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        split_interval: Tuple[int, int],  # [start, end) 0-based
        input_embeds: torch.Tensor = None,
    ):
        start, end = split_interval
        # embed
        if start == 0:
            if input_embeds is None:
                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
            else:
                forward_batch.hidden_states = input_embeds
            # Normalize
            normalizer = torch.tensor(
                self.model.config.hidden_size**0.5, dtype=torch.float16
            )
            forward_batch.hidden_states *= normalizer

        # decoder layer
        for i in range(start, end):
            layer = self.model.layers[i]
            forward_batch.hidden_states, forward_batch.residual = layer(
                positions,
                forward_batch.hidden_states,
                forward_batch,
                forward_batch.residual,
            )

        if end == self.model.config.num_hidden_layers:
            # norm
            forward_batch.hidden_states, _ = self.model.norm(
                forward_batch.hidden_states, forward_batch.residual
            )
            # logits process
            result = self.logits_processor(
                input_ids,
                forward_batch.hidden_states,
                self.model.embed_tokens,
                forward_batch,
            )
        else:
            result = None

        return result

    def get_hidden_dim(self, module_name):
        # return input_dim, output_dim
        if module_name in ["q_proj", "qkv_proj"]:
...
@@ -647,6 +647,69 @@ class Gemma3ForCausalLM(PreTrainedModel):
            input_ids, hidden_states, self.model.embed_tokens, forward_batch
        )

    @torch.no_grad()
    def forward_split_prefill(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        split_interval: Tuple[int, int],  # [start, end) 0-based
        input_embeds: torch.Tensor = None,
    ):
        start, end = split_interval
        # embed
        if start == 0:
            if input_embeds is None:
                hidden_states = self.model.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds

            if positions.dim() == 1:
                positions = einops.rearrange(positions, "s -> 1 s")
            position_embeddings_global = self.model.rotary_emb(hidden_states, positions)
            position_embeddings_local = self.model.rotary_emb_local(
                hidden_states, positions
            )

            forward_batch.hidden_states = hidden_states
            forward_batch.model_specific_states = {
                "positions": positions,
                "position_embeddings_global": position_embeddings_global,
                "position_embeddings_local": position_embeddings_local,
            }

        # decoder layer
        for i in range(start, end):
            layer = self.model.layers[i]
            layer_output = layer(
                positions=forward_batch.model_specific_states["positions"],
                position_embeddings_global=forward_batch.model_specific_states[
                    "position_embeddings_global"
                ],
                position_embeddings_local=forward_batch.model_specific_states[
                    "position_embeddings_local"
                ],
                hidden_states=forward_batch.hidden_states,
                forward_batch=forward_batch,
            )
            forward_batch.hidden_states = layer_output[0]

        if end == self.model.config.num_hidden_layers:
            # norm
            forward_batch.hidden_states = self.model.norm(forward_batch.hidden_states)

            # logits process
            result = self.logits_processor(
                input_ids,
                forward_batch.hidden_states,
                self.model.embed_tokens,
                forward_batch,
            )
        else:
            result = None

        return result

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
...
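Gemma3 is the one model here that uses model_specific_states: its rotary position embeddings are computed once on the first interval and reused by every later interval. Below is a toy, self-contained sketch of that caching pattern in generic Python, not the real model; the cosine stand-in and the function name are made up for illustration.

import torch

def run_interval(state: dict, start: int, end: int, positions: torch.Tensor) -> torch.Tensor:
    if start == 0:
        # Stand-in for rotary_emb / rotary_emb_local in the real model.
        state["position_embeddings"] = torch.cos(positions.float())
    emb = state["position_embeddings"]  # later intervals reuse the cached tensor
    return emb[start:end]               # stand-in for running layers [start, end)

state: dict = {}
positions = torch.arange(8)
run_interval(state, 0, 4, positions)
run_interval(state, 4, 8, positions)  # no recompute; reads from the cache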
@@ -480,6 +480,47 @@ class LlamaForCausalLM(nn.Module):
        else:
            return hidden_states

    @torch.no_grad()
    def forward_split_prefill(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        split_interval: Tuple[int, int],  # [start, end) 0-based
        input_embeds: torch.Tensor = None,
    ) -> Optional[LogitsProcessorOutput]:
        start, end = split_interval
        # embed
        if start == 0:
            if input_embeds is None:
                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
            else:
                forward_batch.hidden_states = input_embeds

        # decoder layer
        for i in range(start, end):
            layer = self.model.layers[i]
            forward_batch.hidden_states, forward_batch.residual = layer(
                positions,
                forward_batch.hidden_states,
                forward_batch,
                forward_batch.residual,
            )

        if end == self.model.config.num_hidden_layers:
            # norm
            hidden_states, _ = self.model.norm(
                forward_batch.hidden_states, forward_batch.residual
            )
            forward_batch.hidden_states = hidden_states
            # logits process
            result = self.logits_processor(
                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
            )
        else:
            result = None

        return result

    @property
    def start_layer(self):
        return self.model.start_layer
...
@@ -15,6 +15,7 @@
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
import time
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
@@ -286,6 +287,42 @@ class QWenLMHeadModel(nn.Module):
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    @torch.no_grad()
    def forward_split_prefill(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        split_interval: Tuple[int, int],  # [start, end) 0-based
    ):
        start, end = split_interval
        # embed
        if start == 0:
            forward_batch.hidden_states = self.transformer.wte(input_ids)

        # decoder layer
        for i in range(start, end):
            layer = self.transformer.h[i]
            forward_batch.hidden_states = layer(
                positions,
                forward_batch.hidden_states,
                forward_batch,
            )

        if end == self.transformer.config.num_hidden_layers:
            # norm
            forward_batch.hidden_states = self.transformer.ln_f(
                forward_batch.hidden_states
            )
            # logits process
            result = self.logits_processor(
                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
            )
        else:
            result = None

        return result

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
...
@@ -481,6 +481,47 @@ class Qwen2ForCausalLM(nn.Module):
        else:
            return hidden_states

    @torch.no_grad()
    def forward_split_prefill(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        split_interval: Tuple[int, int],  # [start, end) 0-based
        input_embeds: torch.Tensor = None,
    ):
        start, end = split_interval
        # embed
        if start == 0:
            if input_embeds is None:
                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
            else:
                forward_batch.hidden_states = input_embeds

        # decoder layer
        for i in range(start, end):
            layer = self.model.layers[i]
            forward_batch.hidden_states, forward_batch.residual = layer(
                positions,
                forward_batch.hidden_states,
                forward_batch,
                forward_batch.residual,
            )

        if end == self.model.config.num_hidden_layers:
            # norm
            hidden_states, _ = self.model.norm(
                forward_batch.hidden_states, forward_batch.residual
            )
            forward_batch.hidden_states = hidden_states
            # logits process
            result = self.logits_processor(
                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
            )
        else:
            result = None

        return result

    @property
    def start_layer(self):
        return self.model.start_layer
...
@@ -406,6 +406,7 @@ class Qwen2MoeModel(nn.Module):
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()
@@ -554,6 +555,49 @@ class Qwen2MoeForCausalLM(nn.Module):
        else:
            return hidden_states

    @torch.no_grad()
    def forward_split_prefill(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        split_interval: Tuple[int, int],  # [start, end) 0-based
        input_embeds: torch.Tensor = None,
    ):
        start, end = split_interval
        # embed
        if start == 0:
            if input_embeds is None:
                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
            else:
                forward_batch.hidden_states = input_embeds

        # decoder layer
        for i in range(start, end):
            with get_global_expert_distribution_recorder().with_current_layer(i):
                layer = self.model.layers[i]
                forward_batch.hidden_states, forward_batch.residual = layer(
                    positions,
                    forward_batch.hidden_states,
                    forward_batch,
                    forward_batch.residual,
                )

        if end == self.model.config.num_hidden_layers:
            # norm
            hidden_states, _ = self.model.norm(
                forward_batch.hidden_states, forward_batch.residual
            )
            forward_batch.hidden_states = hidden_states
            # logits process
            result = self.logits_processor(
                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
            )
        else:
            result = None

        return result

    @property
    def start_layer(self):
        return self.model.start_layer
...
# Adapted from qwen2.py # Adapted from qwen2.py
import logging
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple
@@ -367,6 +366,47 @@ class Qwen3ForCausalLM(nn.Module):
        else:
            return hidden_states

    @torch.no_grad()
    def forward_split_prefill(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        split_interval: Tuple[int, int],  # [start, end) 0-based
        input_embeds: torch.Tensor = None,
    ):
        start, end = split_interval
        # embed
        if start == 0:
            if input_embeds is None:
                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
            else:
                forward_batch.hidden_states = input_embeds

        # decoder layer
        for i in range(start, end):
            layer = self.model.layers[i]
            forward_batch.hidden_states, forward_batch.residual = layer(
                positions,
                forward_batch.hidden_states,
                forward_batch,
                forward_batch.residual,
            )

        if end == self.model.config.num_hidden_layers:
            # norm
            hidden_states, _ = self.model.norm(
                forward_batch.hidden_states, forward_batch.residual
            )
            forward_batch.hidden_states = hidden_states
            # logits process
            result = self.logits_processor(
                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
            )
        else:
            result = None

        return result

    @property
    def start_layer(self):
        return self.model.start_layer
...
@@ -745,6 +745,49 @@ class Qwen3MoeForCausalLM(nn.Module):
        else:
            return hidden_states

    @torch.no_grad()
    def forward_split_prefill(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        split_interval: Tuple[int, int],  # [start, end) 0-based
        input_embeds: torch.Tensor = None,
    ):
        start, end = split_interval
        # embed
        if start == 0:
            if input_embeds is None:
                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
            else:
                forward_batch.hidden_states = input_embeds

        # decoder layer
        for i in range(start, end):
            with get_global_expert_distribution_recorder().with_current_layer(i):
                layer = self.model.layers[i]
                forward_batch.hidden_states, forward_batch.residual = layer(
                    positions,
                    forward_batch.hidden_states,
                    forward_batch,
                    forward_batch.residual,
                )

        if end == self.model.config.num_hidden_layers:
            # norm
            hidden_states, _ = self.model.norm(
                forward_batch.hidden_states, forward_batch.residual
            )
            forward_batch.hidden_states = hidden_states
            # logits process
            result = self.logits_processor(
                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
            )
        else:
            result = None

        return result

    @property
    def start_layer(self):
        return self.model.start_layer
...
@@ -500,6 +500,7 @@ class TboForwardBatchPreparer:
            "capture_hidden_mode",
            "padded_static_len",
            "mrope_positions",  # only used by qwen2-vl, thus not care
            "split_index",  # for split prefill
        ]:
            output_dict[key] = getattr(batch, key)
        if not batch.forward_mode.is_target_verify():
...