Unverified commit e9fd658a, authored by Bowen Wang and committed by GitHub
Browse files

[Feature] Expert Parallelism Load Balancer (EPLB) (#18343)


Signed-off-by: Bowen Wang <abmfy@icloud.com>
parent 07b8fae2
...@@ -23,7 +23,8 @@ ...@@ -23,7 +23,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only DeepseekV2/DeepseekV3 model.""" """Inference-only DeepseekV2/DeepseekV3 model."""
from collections.abc import Iterable import typing
from collections.abc import Callable, Iterable
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -32,8 +33,10 @@ from transformers import PretrainedConfig ...@@ -32,8 +33,10 @@ from transformers import PretrainedConfig
from vllm.attention import Attention from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size get_current_vllm_config)
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -51,7 +54,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -51,7 +54,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import MixtureOfExperts, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -99,11 +102,17 @@ class DeepseekV2MoE(nn.Module): ...@@ -99,11 +102,17 @@ class DeepseekV2MoE(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
enable_eplb: bool = False,
): ):
super().__init__() super().__init__()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank()
self.ep_size = self.ep_group.size()
self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts
if config.hidden_act != "silu": if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. " raise ValueError(f"Unsupported activation: {config.hidden_act}. "
...@@ -120,6 +129,22 @@ class DeepseekV2MoE(nn.Module): ...@@ -120,6 +129,22 @@ class DeepseekV2MoE(nn.Module):
else: else:
self.gate.e_score_correction_bias = None self.gate.e_score_correction_bias = None
# Load balancing settings.
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.enable_eplb = enable_eplb
self.n_redundant_experts = parallel_config.num_redundant_experts
self.n_logical_experts = self.n_routed_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.physical_expert_start = (self.ep_rank *
self.n_local_physical_experts)
self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts)
self.experts = FusedMoE( self.experts = FusedMoE(
num_experts=config.n_routed_experts, num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
...@@ -133,7 +158,9 @@ class DeepseekV2MoE(nn.Module): ...@@ -133,7 +158,9 @@ class DeepseekV2MoE(nn.Module):
topk_group=config.topk_group, topk_group=config.topk_group,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
scoring_func=config.scoring_func, scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias) e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
if config.n_shared_experts is not None: if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
...@@ -503,6 +530,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -503,6 +530,7 @@ class DeepseekV2DecoderLayer(nn.Module):
model_config: ModelConfig, model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
enable_eplb: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -543,6 +571,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -543,6 +571,7 @@ class DeepseekV2DecoderLayer(nn.Module):
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
) )
else: else:
self.mlp = DeepseekV2MLP( self.mlp = DeepseekV2MLP(
...@@ -615,6 +644,7 @@ class DeepseekV2Model(nn.Module): ...@@ -615,6 +644,7 @@ class DeepseekV2Model(nn.Module):
model_config = vllm_config.model_config model_config = vllm_config.model_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
enable_eplb = vllm_config.parallel_config.enable_eplb
self.config = config self.config = config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
...@@ -636,6 +666,7 @@ class DeepseekV2Model(nn.Module): ...@@ -636,6 +666,7 @@ class DeepseekV2Model(nn.Module):
model_config=model_config, model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
enable_eplb=enable_eplb,
), ),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
...@@ -681,7 +712,7 @@ class DeepseekV2Model(nn.Module): ...@@ -681,7 +712,7 @@ class DeepseekV2Model(nn.Module):
return hidden_states return hidden_states
class DeepseekV2ForCausalLM(nn.Module, SupportsPP): class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
...@@ -700,6 +731,44 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -700,6 +731,44 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
self.expert_weights = []
# Set MoE hyperparameters
self.num_moe_layers = (config.num_hidden_layers -
config.first_k_dense_replace)
self.num_expert_groups = config.n_group
self.moe_layers: list[FusedMoE] = []
for layer in self.model.layers:
assert isinstance(layer, DeepseekV2DecoderLayer)
if isinstance(layer.mlp, DeepseekV2MoE):
self.moe_layers.append(layer.mlp.experts)
# Pick the last layer, since the first ones may be dense layers.
example_moe = typing.cast(
DeepseekV2MoE, self.model.layers[config.num_hidden_layers - 1].mlp)
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
    self,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
) -> None:
    """Hand the EPLB state tensors to every MoE layer and collect weights.

    For each entry of `self.moe_layers`, in order, the layer's expert
    weights (as returned by `get_expert_weights()`) are recorded in
    `self.expert_weights`, and the three state tensors are forwarded to the
    layer together with the layer's position in `self.moe_layers`.

    Calling this method again rebuilds `self.expert_weights` from scratch,
    so the list always holds exactly one entry per MoE layer.

    Args:
        expert_load_view: Forwarded unchanged to each layer.
        logical_to_physical_map: Forwarded unchanged to each layer.
        logical_replica_count: Forwarded unchanged to each layer.
    """
    # Reset the registry first: without this a second call would append a
    # duplicate copy of every layer's weights. Clearing in place (instead
    # of rebinding) keeps any existing reference to the list valid.
    self.expert_weights.clear()

    for layer_idx, layer in enumerate(self.moe_layers):
        # Register the expert weights of this layer.
        self.expert_weights.append(layer.get_expert_weights())
        layer.set_eplb_state(
            moe_layer_idx=layer_idx,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
...@@ -752,7 +821,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -752,7 +821,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts) num_experts=self.config.n_routed_experts,
num_redundant_experts=self.num_redundant_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
...@@ -789,24 +859,44 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -789,24 +859,44 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
is_expert_weight = False
for mapping in expert_params_mapping: for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self): # Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name_mapped, self):
continue continue
param = params_dict[name] param = params_dict[name_mapped]
weight_loader = param.weight_loader # We should ask the weight loader to return success or not
weight_loader(param, # here since otherwise we may skip experts with other
# available replicas.
weight_loader = typing.cast(Callable[..., bool],
param.weight_loader)
success = weight_loader(param,
loaded_weight, loaded_weight,
name, name_mapped,
shard_id=shard_id, shard_id=shard_id,
expert_id=expert_id) expert_id=expert_id,
return_success=True)
if success:
break break
else: else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
...@@ -824,6 +914,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -824,6 +914,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, MutableSequence
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
Union, overload, runtime_checkable) Union, overload, runtime_checkable)
...@@ -426,6 +427,73 @@ def is_hybrid( ...@@ -426,6 +427,73 @@ def is_hybrid(
return isinstance(model, IsHybrid) return isinstance(model, IsHybrid)
@runtime_checkable
class MixtureOfExperts(Protocol):
    """
    Structural interface of a mixture-of-experts (MoE) model.

    A model satisfies this protocol when it exposes the attributes and the
    `set_eplb_state` method declared below. Because the protocol is
    runtime-checkable, conformance can be tested with `isinstance`.
    """

    expert_weights: MutableSequence[Iterable[Tensor]]
    """
    Expert weights held by this rank.

    The outer sequence is indexed by layer; each inner iterable yields the
    different parameters of that layer, e.g. up/down projection weights.
    """

    num_moe_layers: int
    """How many MoE layers the model has."""

    num_expert_groups: int
    """How many expert groups the model has."""

    num_logical_experts: int
    """How many logical experts the model has."""

    num_physical_experts: int
    """How many physical experts the model has."""

    num_local_physical_experts: int
    """How many local physical experts the model has."""

    num_routed_experts: int
    """How many routed experts the model has."""

    num_shared_experts: int
    """How many shared experts the model has."""

    num_redundant_experts: int
    """How many redundant experts the model has."""

    def set_eplb_state(
        self,
        expert_load_view: Tensor,
        logical_to_physical_map: Tensor,
        logical_replica_count: Tensor,
    ) -> None:
        """
        Attach the EPLB state to the MoE model.

        The arguments are views of the actual EPLB state, so changes made
        by the EPLB algorithm show up in the model's behavior without any
        further call to install new state.

        Implementations should also gather the model's `expert_weights`
        here rather than in the weight loader: after the initial weight
        loading, further processing such as quantization may still be
        applied to the weights.

        Args:
            expert_load_view: A view of the expert load metrics tensor.
            logical_to_physical_map: Mapping from logical to physical
                experts.
            logical_replica_count: Count of replicas for each logical
                expert.
        """
        ...
def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]:
    """Return whether `model` is an instance of `MixtureOfExperts`."""
    conforms = isinstance(model, MixtureOfExperts)
    return conforms
@runtime_checkable @runtime_checkable
class HasNoOps(Protocol): class HasNoOps(Protocol):
has_noops: ClassVar[Literal[True]] = True has_noops: ClassVar[Literal[True]] = True
......
...@@ -21,6 +21,7 @@ from vllm.attention.layer import Attention ...@@ -21,6 +21,7 @@ from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.config import (CompilationLevel, VllmConfig, from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config) get_layers_from_vllm_config)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group) has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
...@@ -33,7 +34,8 @@ from vllm.logger import init_logger ...@@ -33,7 +34,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import has_step_pooler from vllm.model_executor.models.interfaces import (has_step_pooler,
is_mixture_of_experts)
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.multimodal.utils import group_mm_inputs_by_modality
...@@ -150,6 +152,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -150,6 +152,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Sampler # Sampler
self.sampler = Sampler() self.sampler = Sampler()
self.eplb_state: Optional[EplbState] = None
"""
State of the expert parallelism load balancer.
Will be lazily initialized when the model is loaded.
"""
# Lazy initializations # Lazy initializations
# self.model: nn.Module # Set after load_model # self.model: nn.Module # Set after load_model
# Initialize in initialize_kv_cache # Initialize in initialize_kv_cache
...@@ -1178,6 +1187,24 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1178,6 +1187,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for k, v in self.intermediate_tensors.items() for k, v in self.intermediate_tensors.items()
}) })
def eplb_step(self,
              is_dummy: bool = False,
              is_profile: bool = False) -> None:
    """
    Advance the EPLB (Expert Parallelism Load Balancing) state by one step.

    Does nothing when `enable_eplb` is off in the parallel config.
    Otherwise the EPLB state must already exist and the loaded model must
    pass `is_mixture_of_experts`; the step is then delegated to
    `self.eplb_state.step`.

    Args:
        is_dummy: Passed through to the EPLB state's `step`.
        is_profile: Passed through to the EPLB state's `step`.
    """
    parallel_config = self.parallel_config
    if parallel_config.enable_eplb:
        # Both conditions are expected to hold whenever EPLB is enabled.
        assert self.eplb_state is not None
        assert is_mixture_of_experts(self.model)

        self.eplb_state.step(
            self.model,
            is_dummy,
            is_profile,
            log_stats=parallel_config.eplb_log_balancedness,
        )
def get_dp_padding(self, def get_dp_padding(self,
num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
dp_size = self.vllm_config.parallel_config.data_parallel_size dp_size = self.vllm_config.parallel_config.data_parallel_size
...@@ -1595,6 +1622,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1595,6 +1622,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata() get_kv_transfer_group().clear_connector_metadata()
self.eplb_step()
return ModelRunnerOutput( return ModelRunnerOutput(
req_ids=self.input_batch.req_ids, req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index, req_id_to_index=self.input_batch.req_id_to_index,
...@@ -1729,6 +1758,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1729,6 +1758,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
time_after_load - time_before_load) time_after_load - time_before_load)
prepare_communication_buffer_for_model(self.model) prepare_communication_buffer_for_model(self.model)
if is_mixture_of_experts(
self.model) and self.parallel_config.enable_eplb:
logger.info("EPLB is enabled for model %s.",
self.model_config.model)
self.eplb_state = EplbState.build(
self.model,
self.device,
self.parallel_config,
)
def save_tensorized_model( def save_tensorized_model(
self, self,
tensorizer_config: "TensorizerConfig", tensorizer_config: "TensorizerConfig",
...@@ -1887,6 +1926,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1887,6 +1926,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self, self,
num_tokens: int, num_tokens: int,
capture_attn_cudagraph: bool = False, capture_attn_cudagraph: bool = False,
skip_eplb: bool = False,
is_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Padding for DP # Padding for DP
...@@ -1983,6 +2024,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1983,6 +2024,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens) self.drafter.dummy_run(num_tokens)
# This is necessary to avoid blocking DP.
# For dummy runs, we typically skip EPLB since we don't have any real
# requests to process.
# However, in DP settings, there may be cases when some DP ranks do
# not have any requests to process, so they're executing dummy batches.
# In such cases, we still have to trigger EPLB to make sure
# ranks execute the rearrangement in synchronization.
if not skip_eplb:
self.eplb_step(is_dummy=True, is_profile=is_profile)
logit_indices = np.cumsum(num_scheduled_tokens) - 1 logit_indices = np.cumsum(num_scheduled_tokens) - 1
return hidden_states, hidden_states[logit_indices] return hidden_states, hidden_states[logit_indices]
...@@ -2175,8 +2226,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2175,8 +2226,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Cache the dummy encoder outputs. # Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
# Add `is_profile` here to pre-allocate communication buffers
hidden_states, last_hidden_states \ hidden_states, last_hidden_states \
= self._dummy_run(self.max_num_tokens) = self._dummy_run(self.max_num_tokens, is_profile=True)
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
if self.is_pooling_model: if self.is_pooling_model:
output = self._dummy_pooler_run(hidden_states) output = self._dummy_pooler_run(hidden_states)
...@@ -2210,10 +2262,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2210,10 +2262,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes), for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes),
desc="Capturing CUDA graphs", desc="Capturing CUDA graphs",
total=len(self.cudagraph_batch_sizes)): total=len(self.cudagraph_batch_sizes)):
# We skip EPLB here since we don't want to record dummy metrics
for _ in range( for _ in range(
self.compilation_config.cudagraph_num_of_warmups): self.compilation_config.cudagraph_num_of_warmups):
self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg) self._dummy_run(num_tokens,
self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg) capture_attn_cudagraph=full_cg,
skip_eplb=True)
self._dummy_run(num_tokens,
capture_attn_cudagraph=full_cg,
skip_eplb=True)
end_time = time.perf_counter() end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0] end_free_gpu_memory = torch.cuda.mem_get_info()[0]
......
...@@ -259,9 +259,10 @@ class Worker(WorkerBase): ...@@ -259,9 +259,10 @@ class Worker(WorkerBase):
x for x in warmup_sizes if x not in x for x in warmup_sizes if x not in
self.vllm_config.compilation_config.cudagraph_capture_sizes self.vllm_config.compilation_config.cudagraph_capture_sizes
] ]
# We skip EPLB here since we don't want to record dummy metrics
for size in sorted(warmup_sizes, reverse=True): for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size) logger.info("Compile and warming up model for size %d", size)
self.model_runner._dummy_run(size) self.model_runner._dummy_run(size, skip_eplb=True)
if not self.model_config.enforce_eager: if not self.model_config.enforce_eager:
self.model_runner.capture_model() self.model_runner.capture_model()
...@@ -274,8 +275,12 @@ class Worker(WorkerBase): ...@@ -274,8 +275,12 @@ class Worker(WorkerBase):
max_num_reqs = min(self.scheduler_config.max_num_seqs, max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens) self.scheduler_config.max_num_batched_tokens)
# We skip EPLB here since we don't want to record dummy metrics
hidden_states, last_hidden_states = \ hidden_states, last_hidden_states = \
self.model_runner._dummy_run(num_tokens=max_num_reqs) self.model_runner._dummy_run(
num_tokens=max_num_reqs,
skip_eplb=True,
)
if self.model_runner.is_pooling_model: if self.model_runner.is_pooling_model:
self.model_runner._dummy_pooler_run(hidden_states) self.model_runner._dummy_pooler_run(hidden_states)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment