Commit 3f5c2eea authored by zhuwenwen's avatar zhuwenwen
Browse files

add mla tpsp and moe share experts computation communication overlap

parent 8375370f
...@@ -178,6 +178,8 @@ if TYPE_CHECKING: ...@@ -178,6 +178,8 @@ if TYPE_CHECKING:
VLLM_P2P_BUF_TOKENS: int = 30000 VLLM_P2P_BUF_TOKENS: int = 30000
VLLM_SCHED_ENABLE_MINIMAL_INJECTION: bool = False VLLM_SCHED_ENABLE_MINIMAL_INJECTION: bool = False
VLLM_USE_PD_SPLIT: bool = False VLLM_USE_PD_SPLIT: bool = False
VLLM_ENABLE_MLA_SP: bool = False
VLLM_ENABLE_MLA_QKV_MERGE: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1094,68 +1096,89 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1094,68 +1096,89 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_ATTN_PA": "VLLM_USE_FLASH_ATTN_PA":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "True").lower() in lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use apex for rmsnorm # vLLM will use apex for rmsnorm
"VLLM_USE_APEX_RN": "VLLM_USE_APEX_RN":
lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use global cache for moe # vLLM will use global cache for moe
"VLLM_USE_GLOBAL_CACHE13": "VLLM_USE_GLOBAL_CACHE13":
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop for deepseek-v3 # vLLM will use lightop for deepseek-v3
"VLLM_USE_LIGHTOP": "VLLM_USE_LIGHTOP":
lambda: (os.environ.get("VLLM_USE_LIGHTOP", "False").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use elenmentwise not triton_ # vLLM will use elenmentwise not triton_
"VLLM_USE_OPT_ZEROS": "VLLM_USE_OPT_ZEROS":
lambda: (os.environ.get("VLLM_USE_OPT_ZEROS", "False").lower() in lambda: (os.environ.get("VLLM_USE_OPT_ZEROS", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use opt cat for deepseek-v3 # vLLM will use opt cat for deepseek-v3
"VLLM_USE_OPT_CAT": "VLLM_USE_OPT_CAT":
lambda: (os.environ.get("VLLM_USE_OPT_CAT", "False").lower() in lambda: (os.environ.get("VLLM_USE_OPT_CAT", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use triton moe_sum # vLLM will use triton moe_sum
"VLLM_USE_OPT_MOE_SUM": "VLLM_USE_OPT_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop moe_sum_mul_add # vLLM will use lightop moe_sum_mul_add
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD": "VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD", "False").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop moe_sum # vLLM will use lightop moe_sum
"VLLM_USE_LIGHTOP_MOE_SUM": "VLLM_USE_LIGHTOP_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "True").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop moe_align_block_size # vLLM will use lightop moe_align_block_size
"VLLM_USE_LIGHTOP_MOE_ALIGN": "VLLM_USE_LIGHTOP_MOE_ALIGN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "True").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use opt merge_aatn_states, not triton # vLLM will use opt merge_aatn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
("true", "1")), ("true", "1")),
# vllm will use rmsquant fused op # vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT": "USE_FUSED_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in
("true", "1")), ("true", "1")),
# vllm will use silu_mul_quant fused op # vllm will use silu_mul_quant fused op
"USE_FUSED_SILU_MUL_QUANT": "USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")), ("true", "1")),
# vllm pd separation will be used async # vllm pd separation will be used async
"VLLM_P2P_ASYNC": "VLLM_P2P_ASYNC":
lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))), lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))),
# pd separation p2p async buf tokens # pd separation p2p async buf tokens
"VLLM_P2P_BUF_TOKENS": "VLLM_P2P_BUF_TOKENS":
lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")), lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")),
# vllm will enable minimal injection for pipeline parallel scheduling # vllm will enable minimal injection for pipeline parallel scheduling
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION": "VLLM_SCHED_ENABLE_MINIMAL_INJECTION":
lambda: (os.getenv("VLLM_SCHED_ENABLE_MINIMAL_INJECTION", "0").lower() in lambda: (os.getenv("VLLM_SCHED_ENABLE_MINIMAL_INJECTION", "0").lower() in
("true", "1")), ("true", "1")),
# vLLM will split prefill and decode, not mix up # vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT": "VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in
("true", "1")), ("true", "1")),
"VLLM_ENABLE_MLA_SP":
lambda: bool(int(os.getenv("VLLM_ENABLE_MLA_SP", "0"))),
"VLLM_ENABLE_MLA_QKV_MERGE":
lambda: bool(int(os.getenv("VLLM_ENABLE_MLA_QKV_MERGE", "0"))),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -637,6 +637,13 @@ def determine_expert_map( ...@@ -637,6 +637,13 @@ def determine_expert_map(
return (local_num_experts, expert_map) return (local_num_experts, expert_map)
EventType = Enum(
'EventType',
['Main', 'Attention', 'QCAllgather', 'KVFinish', 'MoeShared', 'MoeChunkingOverlap', 'MoeAllgather', 'MoeReduceScatter'],
start=0,
)
class FusedMoE(torch.nn.Module): class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models. """FusedMoE layer for MoE models.
......
...@@ -14,7 +14,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -14,7 +14,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
...@@ -454,6 +455,86 @@ class ReplicatedLinear(LinearBase): ...@@ -454,6 +455,86 @@ class ReplicatedLinear(LinearBase):
return s return s
class MergedReplicatedLinear(ReplicatedLinear):
"""Merged replicated linear layer
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
):
self.output_sizes = output_sizes
super().__init__(input_size,
sum(output_sizes),
bias,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
return_bias=return_bias)
def weight_loader(self,
param: Union[Parameter, BasevLLMParameter],
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
assert loaded_shard_id is not None
assert loaded_shard_id < len(self.output_sizes)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
assert self.quant_method is not None
assert isinstance(self.quant_method, (Fp8LinearMethod, Fp8MoEMethod))
weight_block_size = self.quant_method.quant_config.weight_block_size
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
)
shard_size = (
(self.output_sizes[loaded_shard_id] + block_n - 1)
// block_n
)
elif isinstance(param, PerTensorScaleParameter) and current_platform.is_rocm():
shard_offset = loaded_shard_id
shard_size = 1
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
start_offset = shard_offset
end_offset = start_offset + shard_size
assert loaded_weight.shape == param.data[start_offset:end_offset, ...].shape, (
f"Expected shape {param.data[start_offset:end_offset, ...].shape}, got {loaded_weight.shape}"
)
if envs.VLLM_USE_NN and not is_quantization:
loaded_weight = loaded_weight.t()
param.data[start_offset:end_offset, ...].copy_(loaded_weight)
class ColumnParallelLinear(LinearBase): class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism. """Linear layer with column parallelism.
...@@ -1390,6 +1471,7 @@ class RowParallelLinear(LinearBase): ...@@ -1390,6 +1471,7 @@ class RowParallelLinear(LinearBase):
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
sp_parallel: bool = False,
): ):
# Divide the weight matrix along the first dimension. # Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
...@@ -1397,6 +1479,7 @@ class RowParallelLinear(LinearBase): ...@@ -1397,6 +1479,7 @@ class RowParallelLinear(LinearBase):
self.input_size_per_partition = divide(input_size, self.tp_size) self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size] self.output_partition_sizes = [output_size]
self.sp_parallel = sp_parallel
super().__init__(input_size, super().__init__(input_size,
output_size, output_size,
...@@ -1526,7 +1609,10 @@ class RowParallelLinear(LinearBase): ...@@ -1526,7 +1609,10 @@ class RowParallelLinear(LinearBase):
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel) output = self.tbo_all_reduce(output_parallel)
else: else:
output = tensor_model_parallel_all_reduce(output_parallel) if self.sp_parallel:
output = tensor_model_parallel_reduce_scatter(output_parallel.contiguous(), dim=0)
else:
output = tensor_model_parallel_all_reduce(output_parallel)
else: else:
output = output_parallel output = output_parallel
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment