Unverified Commit bad7c26f authored by Ying Sheng, committed by GitHub

[PP] Fix init_memory_pool desync & add PP for mixtral (#6223)

parent 12319a67
@@ -229,6 +229,18 @@ jobs:
           cd test/srt
           python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache
 
+      - name: Benchmark offline decode throughput (PP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_pp_offline_throughput_default_decode
+
+      - name: Benchmark offline prefill throughput (PP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_pp_long_context_prefill
+
   accuracy-test-1-gpu:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
       github.event.pull_request.draft == false
...
@@ -468,9 +468,6 @@ class PrefillAdder:
                 return AddReqResult.OTHER
 
         with self._lock_node(req.last_node):
-            if total_tokens > self.rem_total_tokens:
-                return AddReqResult.NO_TOKEN
-
             if (
                 enable_hierarchical_cache
                 and req.last_node_global is not None
...
@@ -719,7 +719,7 @@ class Scheduler(
                 server_is_idle = False
                 result = self.run_batch(self.cur_batch)
 
-            # send the outputs to the next step
+            # (last rank) send the outputs to the next step
             if self.pp_group.is_last_rank:
                 if self.cur_batch:
                     next_token_ids, bids[mb_id] = (
@@ -759,18 +759,18 @@ class Scheduler(
                     self.process_batch_result(mbs[next_mb_id], output_result)
                     last_mbs[next_mb_id] = mbs[next_mb_id]
 
-            # carry the outputs to the next stage
+            # (not last rank)
             if not self.pp_group.is_last_rank:
                 if self.cur_batch:
                     bids[mb_id] = result.bid
+                # carry the outputs to the next stage
+                # send the outputs from the last round to let the next stage worker run post processing
                 if pp_outputs:
-                    # send the outputs from the last round to let the next stage worker run post processing
                     self.pp_group.send_tensor_dict(
                         pp_outputs.tensors,
                         all_gather_group=self.attn_tp_group,
                     )
-            if not self.pp_group.is_last_rank:
                 # send out reqs to the next stage
                 dp_offset = self.dp_rank * self.attn_tp_size
                 if self.attn_tp_rank == 0:
...
@@ -32,6 +32,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.distributed import (
     get_tp_group,
+    get_world_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
@@ -404,7 +405,10 @@ class ModelRunner:
         )
 
         min_per_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
@@ -716,7 +720,10 @@ class ModelRunner:
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         if self.use_mla_backend:
             num_layers = (
...
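Note: the two `get_available_gpu_memory` hunks above in ModelRunner are the "init_memory_pool desync" fix from the commit title: the free-memory probe now takes the minimum over the whole world group (TP × PP) instead of only the TP group, so every pipeline stage derives the same KV-pool budget. A rough standalone sketch of that reasoning; the helper name and cell size below are made up for illustration, not sglang APIs:

```python
# Hypothetical sketch: size the KV pool from the *global* minimum free memory.
# If PP rank 0 sees 70 GiB free and PP rank 1 sees 60 GiB, sizing each pool
# locally yields different token capacities and the schedulers desynchronize.
def max_num_tokens(free_gib: float, cell_size_bytes: int, reserve_gib: float = 2.0) -> int:
    usable_bytes = max(free_gib - reserve_gib, 0.0) * (1 << 30)
    return int(usable_bytes // cell_size_bytes)

per_rank_free_gib = [70.0, 60.0]      # what each PP rank measures locally
cell = 2 * 32 * 8 * 128               # made-up bytes per cached token (K/V, 32 layers, 8 heads, dim 128, 1-byte elems)
local_sizes = [max_num_tokens(f, cell) for f in per_rank_free_gib]
global_size = max_num_tokens(min(per_rank_free_gib), cell)
print(local_sizes, "->", global_size)  # ranks disagree locally; the MIN makes them agree
```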
@@ -16,13 +16,15 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""
-from typing import Iterable, Optional, Tuple
+import logging
+from typing import Iterable, Optional, Tuple, Union
 
 import torch
 from torch import nn
 from transformers import MixtralConfig
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
@@ -38,14 +40,17 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)
 
 
 class MixtralMoE(nn.Module):
@@ -257,24 +262,32 @@ class MixtralModel(nn.Module):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
-        self.layers = nn.ModuleList(
-            [
-                MixtralDecoderLayer(
-                    config,
-                    i,
-                    quant_config=quant_config,
-                    prefix=add_prefix(f"layers.{i}", prefix),
-                )
-                for i in range(config.num_hidden_layers)
-            ]
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: MixtralDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="layers",
+            return_tuple=True,
        )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
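Here `make_layers` builds only the contiguous slice of decoder layers owned by the current pipeline rank and leaves placeholders elsewhere. A minimal sketch of that partitioning idea, assuming an even contiguous split; `partition_layers` and `MissingLayer` below are illustrative stand-ins, not sglang's `make_layers`/`PPMissingLayer`:

```python
# Hypothetical sketch of pipeline-parallel layer partitioning: each PP rank
# materializes a contiguous slice [start, end) of the decoder stack and keeps
# cheap placeholders elsewhere so layer indices stay globally consistent.
from typing import Callable, Tuple
import torch.nn as nn

class MissingLayer(nn.Identity):
    """Placeholder for layers owned by another pipeline rank."""

def partition_layers(
    num_layers: int, pp_rank: int, pp_size: int, build: Callable[[int], nn.Module]
) -> Tuple[nn.ModuleList, int, int]:
    per_rank = (num_layers + pp_size - 1) // pp_size
    start = pp_rank * per_rank
    end = min(start + per_rank, num_layers)
    layers = nn.ModuleList(
        [build(i) if start <= i < end else MissingLayer() for i in range(num_layers)]
    )
    return layers, start, end

# 32 layers over pp_size=2: rank 1 owns [16, 32) and holds 16 real modules.
layers, start, end = partition_layers(32, pp_rank=1, pp_size=2, build=lambda i: nn.Linear(8, 8))
print(start, end, sum(not isinstance(m, MissingLayer) for m in layers))  # 16 32 16
```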
@@ -282,18 +295,35 @@ class MixtralModel(nn.Module):
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
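The forward split above is the core of the PP port: non-last ranks return their activations as `PPProxyTensors` instead of running the final norm and logits. A toy single-process sketch of that hand-off, with a plain dict standing in for `PPProxyTensors` and the residual stream omitted for brevity (none of the names below are sglang code):

```python
# Hypothetical two-stage pipeline in one process: stage 0 embeds and returns a
# proxy dict, stage 1 consumes the dict and finishes with the final norm.
import torch
import torch.nn as nn

class Stage(nn.Module):
    def __init__(self, layers: nn.ModuleList, is_first: bool, is_last: bool):
        super().__init__()
        self.layers = layers
        self.is_first, self.is_last = is_first, is_last
        self.embed = nn.Embedding(100, 16) if is_first else None
        self.norm = nn.LayerNorm(16) if is_last else None

    def forward(self, input_ids=None, proxy=None):
        hidden = self.embed(input_ids) if self.is_first else proxy["hidden_states"]
        for layer in self.layers:
            hidden = layer(hidden)
        if not self.is_last:
            return {"hidden_states": hidden}  # shipped to the next pipeline rank
        return self.norm(hidden)              # only the last rank finishes the model

stage0 = Stage(nn.ModuleList([nn.Linear(16, 16) for _ in range(2)]), True, False)
stage1 = Stage(nn.ModuleList([nn.Linear(16, 16) for _ in range(2)]), False, True)
out = stage1(proxy=stage0(input_ids=torch.randint(0, 100, (1, 4))))
print(out.shape)  # torch.Size([1, 4, 16])
```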
@@ -306,6 +336,7 @@ class MixtralForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = MixtralModel(
@@ -322,12 +353,31 @@ class MixtralForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -348,6 +398,17 @@ class MixtralForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
 
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue
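The new filter above skips checkpoint tensors whose layer index falls outside `[start_layer, end_layer)`, so each pipeline rank only loads the weights it actually owns; weights without a layer index (embeddings, norm, `lm_head`) pass through and are resolved by the `params_dict` lookup below. A small sketch of the same rule, using a regex as a stand-in for sglang's `get_layer_id` helper:

```python
# Hypothetical sketch of layer-aware weight filtering for pipeline parallelism.
import re
from typing import Optional

_LAYER_RE = re.compile(r"\blayers\.(\d+)\.")

def layer_id_from_name(name: str) -> Optional[int]:
    match = _LAYER_RE.search(name)
    return int(match.group(1)) if match else None

def owned_by_this_rank(name: str, start_layer: int, end_layer: int) -> bool:
    layer_id = layer_id_from_name(name)
    # Non-layer weights (embeddings, lm_head, final norm) are never skipped here.
    return layer_id is None or start_layer <= layer_id < end_layer

# Rank owning layers [16, 32):
print(owned_by_this_rank("model.layers.20.mlp.gate.weight", 16, 32))          # True
print(owned_by_this_rank("model.layers.3.self_attn.qkv_proj.weight", 16, 32)) # False
print(owned_by_this_rank("lm_head.weight", 16, 32))                           # True
```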
@@ -398,11 +459,14 @@ class MixtralForCausalLM(nn.Module):
                 if name is None:
                     continue
 
-                param = params_dict[name]
-                weight_loader = getattr(
-                    param, "weight_loader", default_weight_loader
-                )
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
 
 EntryClass = MixtralForCausalLM
...
@@ -347,6 +347,12 @@ class ServerArgs:
                 f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
 
+        if self.pp_size > 1:
+            self.disable_overlap_schedule = True
+            logger.warning(
+                "Pipeline parallelism is incompatible with overlap schedule."
+            )
+
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE
...
@@ -282,7 +282,9 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper
 
 
-def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
+def get_available_gpu_memory(
+    device, gpu_id, distributed=False, empty_cache=True, cpu_group=None
+):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -344,10 +346,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
         free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()
 
     if distributed:
-        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
-            torch.device(device, gpu_id)
+        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
+        torch.distributed.all_reduce(
+            tensor, op=torch.distributed.ReduceOp.MIN, group=cpu_group
         )
-        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
         free_gpu_memory = tensor.item()
 
     return free_gpu_memory / (1 << 30)
...
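For reference, a minimal single-process demo of the new reduction pattern: the free-memory scalar stays on the CPU and is MIN-reduced over a CPU (gloo) group passed in as `cpu_group`, rather than first being moved to `cuda:gpu_id` as the old code did. The group setup below is illustrative, not sglang code:

```python
# Minimal demo (not sglang code) of MIN-reducing a host-side scalar over a
# gloo (CPU) process group, mirroring what get_available_gpu_memory now does.
import os
import torch
import torch.distributed as dist

def min_free_memory_gib(free_bytes: float, cpu_group=None) -> float:
    tensor = torch.tensor(float(free_bytes), dtype=torch.float32)  # stays on CPU
    if dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group)
    return tensor.item() / (1 << 30)

if __name__ == "__main__":
    # One-rank gloo group standing in for the world (TP x PP) CPU group.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29531")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    print(f"{min_free_memory_gib(40 * (1 << 30), dist.group.WORLD):.2f} GiB")
    dist.destroy_process_group()
```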
@@ -272,6 +272,50 @@ class TestBenchServing(CustomTestCase):
         else:
             self.assertGreater(res["output_throughput"], 2200)
 
+    def test_pp_offline_throughput_default_decode(self):
+        res = run_bench_serving(
+            model=DEFAULT_MOE_MODEL_NAME_FOR_TEST,
+            num_prompts=1000,
+            request_rate=float("inf"),
+            random_input_len=1,
+            random_output_len=1024,
+            other_server_args=["--pp", "2"],
+            need_warmup=True,
+            seed=42,
+        )
+
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_pp_offline_throughput_default_decode\n"
+                f'Output throughput: {res["output_throughput"]:.2f} token/s\n'
+            )
+            self.assertGreater(res["output_throughput"], 7500)
+
+    def test_pp_long_context_prefill(self):
+        res = run_bench_serving(
+            model="meta-llama/Llama-3.3-70B-Instruct",
+            num_prompts=4,
+            request_rate=float("inf"),
+            random_input_len=128000,
+            random_output_len=1,
+            dataset_name="random",
+            other_server_args=[
+                "--quantization",
+                "fp8",
+                "--pp",
+                "2",
+            ],
+            need_warmup=False,
+            seed=42,
+        )
+
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_pp_long_context_prefill\n"
+                f'Input throughput: {res["input_throughput"]:.2f} token/s\n'
+            )
+            self.assertGreater(res["input_throughput"], 4000)
+
 
 if __name__ == "__main__":
     unittest.main()