Commit d0dafaf4 authored by 王敏's avatar 王敏
Browse files

[feat]添加ep moe功能

parent a27fdb55
...@@ -41,7 +41,8 @@ def benchmark_config( ...@@ -41,7 +41,8 @@ def benchmark_config(
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
num_iters: int = 100, num_iters: int = 100,
nn_moe: Optional[bool] = False nn_moe: Optional[bool] = False,
moe_ep_size: int = 1,
) -> float: ) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype) x = torch.randn(num_tokens, hidden_size, dtype=dtype)
...@@ -140,6 +141,9 @@ def benchmark_config( ...@@ -140,6 +141,9 @@ def benchmark_config(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
use_nn_moe=nn_moe, use_nn_moe=nn_moe,
moe_ep_size=moe_ep_size,
start_expert=0,
end_expert=num_experts
) )
# JIT compilation & warmup # JIT compilation & warmup
...@@ -406,7 +410,8 @@ class BenchmarkWorker: ...@@ -406,7 +410,8 @@ class BenchmarkWorker:
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
search_space: List[Dict[str, int]], search_space: List[Dict[str, int]],
nn_moe: Optional[bool] = False nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1
) -> Dict[str, int]: ) -> Dict[str, int]:
best_config = None best_config = None
best_time = float("inf") best_time = float("inf")
...@@ -430,7 +435,8 @@ class BenchmarkWorker: ...@@ -430,7 +435,8 @@ class BenchmarkWorker:
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
num_iters=20, num_iters=20,
nn_moe=nn_moe) nn_moe=nn_moe,
moe_ep_size=moe_ep_size)
except triton.runtime.autotuner.OutOfResources: except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile. # Some configurations may be invalid and fail to compile.
continue continue
...@@ -520,29 +526,44 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int, ...@@ -520,29 +526,44 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
print(args) print(args)
moe_ep_size = args.moe_ep_size
tp_size = args.tp_size
if moe_ep_size > 1:
tp_size = tp_size // moe_ep_size
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
args.model, trust_remote_code=args.trust_remote_code) args.model, trust_remote_code=args.trust_remote_code)
if config.architectures[0] == "DbrxForCausalLM": if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts E = config.ffn_config.moe_num_experts
E = E // moe_ep_size
topk = config.ffn_config.moe_top_k topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "JambaForCausalLM": elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts E = config.num_experts
E = E // moe_ep_size
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM": elif config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM":
E = config.n_routed_experts E = config.n_routed_experts
E = E // moe_ep_size
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts
E = E // moe_ep_size
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // tp_size
else: else:
# Default: Mixtral. # Default: Mixtral.
E = config.num_local_experts E = config.num_local_experts
E = E // moe_ep_size
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // tp_size
hidden_size = config.hidden_size hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
...@@ -582,7 +603,7 @@ def main(args: argparse.Namespace): ...@@ -582,7 +603,7 @@ def main(args: argparse.Namespace):
start = time.time() start = time.time()
configs = _distribute( configs = _distribute(
"tune", [(batch_size, E, shard_intermediate_size, hidden_size, "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space, args.nn_moe) topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space, args.nn_moe, moe_ep_size)
for batch_size in batch_sizes]) for batch_size in batch_sizes])
best_configs = { best_configs = {
M: sort_config(config) M: sort_config(config)
...@@ -622,6 +643,7 @@ if __name__ == "__main__": ...@@ -622,6 +643,7 @@ if __name__ == "__main__":
parser.add_argument("--tune", action="store_true") parser.add_argument("--tune", action="store_true")
parser.add_argument("--nn_moe", type=bool, default=True) parser.add_argument("--nn_moe", type=bool, default=True)
parser.add_argument("--trust-remote-code", action="store_true") parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--moe-ep-size", type=int, default=1)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
...@@ -279,6 +279,103 @@ __global__ void moe_sum_kernel( ...@@ -279,6 +279,103 @@ __global__ void moe_sum_kernel(
} }
} }
template <typename scalar_t, typename token_cnts_t>
__global__ void ep_moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
int32_t* sorted_token_ids,
int32_t* expert_ids,
int32_t* total_tokens_post_pad,
int32_t num_experts,
int32_t block_size, size_t numel,
int32_t start_expert, int32_t end_expert) {
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
extern __shared__ int32_t shared_mem[];
int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1)
token_cnts_t* tokens_cnts =
(token_cnts_t*)(shared_mem + num_experts +
1); // 2d tensor with shape (blockDim.x + 1, num_experts)
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
}
/**
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are
* assigned to expert expert_index.
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
if (topk_ids[i] >= start_expert && topk_ids[i] < end_expert) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i] - start_expert)];
}
}
__syncthreads();
// For each expert we accumulate the token counts from the different threads.
if (threadIdx.x < num_experts) {
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
}
__syncthreads();
// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i - 1] +
CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
block_size) *
block_size;
}
*total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
}
__syncthreads();
/**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
}
/**
* Each thread processes a token shard, calculating the index of each token
* after sorting by expert number. Given the example topk_ids =
* [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
* *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
* padding value(preset in python).
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
if (expert_id >= start_expert && expert_id < end_expert) {
expert_id -= start_expert;
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
}
}
}
} // namespace moe } // namespace moe
} // namespace vllm } // namespace vllm
...@@ -371,6 +468,109 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -371,6 +468,109 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
} }
} }
void ep_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad,
int64_t start_expert, int64_t end_expert) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int device_max_shared_mem;
auto dev = topk_ids.get_device();
cudaDeviceGetAttribute(&device_max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
const int32_t shared_mem_i32 =
((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
const int32_t shared_mem_i16 =
((num_thread + 1) * num_experts) * sizeof(uint16_t) +
(num_experts + 1) * sizeof(int32_t);
// bool use_global_memory = false;
// bool use_i16 = false; // Use uint16_t for shared memory token counts
// if (shared_mem_i32 < device_max_shared_mem) {
// // Do nothing in this case. We're all set to use int32_t token counts
// } else if (shared_mem_i16 < device_max_shared_mem &&
// topk_ids.numel() <= 65535) {
// // when nelements of topk_ids is smaller than 65535 (max value of uint16),
// // element value of token_cnts would also smaller than 65535,
// // so we can use uint16 as dtype of token_cnts
// use_i16 = true;
// } else {
// use_global_memory = true;
// }
// if (use_global_memory) {
// VLLM_DISPATCH_INTEGRAL_TYPES(
// topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// // tensors
// const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
// auto options_int = torch::TensorOptions()
// .dtype(torch::kInt)
// .device(topk_ids.device());
// torch::Tensor token_cnts_buffer =
// torch::empty({(num_experts + 1) * num_experts}, options_int);
// torch::Tensor cumsum_buffer =
// torch::empty({num_experts + 1}, options_int);
// auto kernel =
// vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
// kernel<<<1, num_thread, 0, stream>>>(
// topk_ids.data_ptr<scalar_t>(),
// sorted_token_ids.data_ptr<int32_t>(),
// experts_ids.data_ptr<int32_t>(),
// num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
// topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(),
// cumsum_buffer.data_ptr<int32_t>());
// });
// } else if (use_i16) {
// VLLM_DISPATCH_INTEGRAL_TYPES(
// topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// // set dynamic shared mem
// auto kernel =
// vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
// AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
// (void*)kernel, shared_mem_i16));
// kernel<<<1, num_thread, shared_mem_i16, stream>>>(
// topk_ids.data_ptr<scalar_t>(),
// sorted_token_ids.data_ptr<int32_t>(),
// experts_ids.data_ptr<int32_t>(),
// num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
// topk_ids.numel());
// });
// } else {
// VLLM_DISPATCH_INTEGRAL_TYPES(
// topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// auto kernel =
// vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
// AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
// (void*)kernel, shared_mem_i32));
// kernel<<<1, num_thread, shared_mem_i32, stream>>>(
// topk_ids.data_ptr<scalar_t>(),
// sorted_token_ids.data_ptr<int32_t>(),
// experts_ids.data_ptr<int32_t>(),
// num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
// topk_ids.numel());
// });
// }
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "ep_moe_align_block_size_kernel", [&] {
auto kernel =
vllm::moe::ep_moe_align_block_size_kernel<scalar_t, int32_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem_i32));
kernel<<<1, num_thread, shared_mem_i32, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), start_expert, end_expert);
});
}
void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, int64_t block_size,
torch::Tensor sorted_token_ids, torch::Tensor sorted_token_ids,
......
...@@ -13,6 +13,12 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -13,6 +13,12 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad); torch::Tensor num_tokens_post_pad);
void ep_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad,
int64_t start_expert, int64_t end_expert);
void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, int64_t block_size,
torch::Tensor sorted_token_ids, torch::Tensor sorted_token_ids,
......
...@@ -22,6 +22,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -22,6 +22,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()"); " Tensor! num_tokens_post_pad) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
ops.def(
"ep_moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad,"
" int start_expert, int end_expert) -> ()");
ops.impl("ep_moe_align_block_size", torch::kCUDA, &ep_moe_align_block_size);
// temporarily adapted from // temporarily adapted from
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a // https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
m.def( m.def(
......
...@@ -1378,6 +1378,16 @@ def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, ...@@ -1378,6 +1378,16 @@ def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
torch.ops._moe_C.sgl_moe_align_block_size(topk_ids, num_experts, torch.ops._moe_C.sgl_moe_align_block_size(topk_ids, num_experts,
block_size, sorted_token_ids, block_size, sorted_token_ids,
experts_ids, num_tokens_post_pad) experts_ids, num_tokens_post_pad)
def ep_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
start_expert, end_expert) -> None:
torch.ops._C.ep_moe_align_block_size(topk_ids, num_experts, block_size,
sorted_token_ids, experts_ids,
num_tokens_post_pad, start_expert,
end_expert)
def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
......
...@@ -1342,6 +1342,8 @@ class ParallelConfig: ...@@ -1342,6 +1342,8 @@ class ParallelConfig:
rank: int = 0 rank: int = 0
moe_ep_size: Optional[int] = 1
def compute_hash(self): def compute_hash(self):
""" """
Provide a hash that uniquely identifies all the configs Provide a hash that uniquely identifies all the configs
......
...@@ -206,6 +206,8 @@ class EngineArgs: ...@@ -206,6 +206,8 @@ class EngineArgs:
calculate_kv_scales: Optional[bool] = None calculate_kv_scales: Optional[bool] = None
moe_ep_size: int = 1
def __post_init__(self): def __post_init__(self):
if not self.tokenizer: if not self.tokenizer:
self.tokenizer = self.model self.tokenizer = self.model
...@@ -417,6 +419,10 @@ class EngineArgs: ...@@ -417,6 +419,10 @@ class EngineArgs:
type=int, type=int,
default=EngineArgs.tensor_parallel_size, default=EngineArgs.tensor_parallel_size,
help='Number of tensor parallel replicas.') help='Number of tensor parallel replicas.')
parser.add_argument('--moe-ep-size',
type=int,
default=EngineArgs.moe_ep_size,
help='Number of moe expert parallel replicas.')
parser.add_argument( parser.add_argument(
'--max-parallel-loading-workers', '--max-parallel-loading-workers',
type=int, type=int,
...@@ -1123,6 +1129,7 @@ class EngineArgs: ...@@ -1123,6 +1129,7 @@ class EngineArgs:
ray_workers_use_nsight=self.ray_workers_use_nsight, ray_workers_use_nsight=self.ray_workers_use_nsight,
distributed_executor_backend=self.distributed_executor_backend, distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls, worker_cls=self.worker_cls,
moe_ep_size=self.moe_ep_size,
) )
max_model_len = model_config.max_model_len max_model_len = model_config.max_model_len
......
...@@ -633,6 +633,66 @@ def moe_align_block_size( ...@@ -633,6 +633,66 @@ def moe_align_block_size(
return sorted_ids, expert_ids, num_tokens_post_pad return sorted_ids, expert_ids, num_tokens_post_pad
def moe_ep_align_block_size(
topk_ids: torch.Tensor, block_size: int,
num_experts: int, start_expert: int,
end_expert: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according
to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding,
ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process
so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions
align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in
the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
ops.ep_moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad, start_expert,
end_expert)
return sorted_ids, expert_ids, num_tokens_post_pad
def invoke_fused_moe_kernel(A: torch.Tensor, def invoke_fused_moe_kernel(A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
C: torch.Tensor, C: torch.Tensor,
...@@ -1029,11 +1089,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1029,11 +1089,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None: use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale, use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape,
use_nn_moe) use_nn_moe, moe_ep_size=moe_ep_size,
start_expert=start_expert, end_expert=end_expert)
def inplace_fused_experts_fake( def inplace_fused_experts_fake(
...@@ -1052,7 +1116,10 @@ def inplace_fused_experts_fake( ...@@ -1052,7 +1116,10 @@ def inplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None: use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1) -> None:
pass pass
...@@ -1080,12 +1147,16 @@ def outplace_fused_experts( ...@@ -1080,12 +1147,16 @@ def outplace_fused_experts(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, use_fp8_w8a8, use_int8_w8a16, False, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape, a1_scale, a2_scale, block_shape,
use_nn_moe) use_nn_moe, moe_ep_size=moe_ep_size,
start_expert=start_expert, end_expert=end_expert)
def outplace_fused_experts_fake( def outplace_fused_experts_fake(
...@@ -1104,7 +1175,10 @@ def outplace_fused_experts_fake( ...@@ -1104,7 +1175,10 @@ def outplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -1132,7 +1206,10 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1132,7 +1206,10 @@ def fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False): use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1):
if inplace: if inplace:
torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2, torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
topk_weights, topk_ids, topk_weights, topk_ids,
...@@ -1140,14 +1217,19 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1140,14 +1217,19 @@ def fused_experts(hidden_states: torch.Tensor,
use_int4_w4a16, w1_scale, use_int4_w4a16, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape, a2_scale, block_shape,
use_nn_moe) use_nn_moe,
moe_ep_size=moe_ep_size,
start_expert=start_expert,
end_expert=end_expert)
return hidden_states return hidden_states
else: else:
return torch.ops.vllm.outplace_fused_experts( return torch.ops.vllm.outplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp, use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape, a1_scale, a2_scale, block_shape,
use_nn_moe) use_nn_moe, moe_ep_size=moe_ep_size,
start_expert=start_expert,
end_expert=end_expert)
def fused_experts_impl(hidden_states: torch.Tensor, def fused_experts_impl(hidden_states: torch.Tensor,
...@@ -1166,7 +1248,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1166,7 +1248,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False): use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1):
# Check constraints. # Check constraints.
if use_int4_w4a16: if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[ assert hidden_states.shape[1] // 2 == w1.shape[
...@@ -1219,6 +1304,9 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1219,6 +1304,9 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1] if not use_nn_moe else w2.shape[2]), intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1] if not use_nn_moe else w2.shape[2]),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
if moe_ep_size > 1:
intermediate_cache3.zero_()
if hidden_states.dtype == torch.bfloat16: if hidden_states.dtype == torch.bfloat16:
compute_type = tl.bfloat16 compute_type = tl.bfloat16
...@@ -1259,6 +1347,14 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1259,6 +1347,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
sorted_token_ids, expert_ids, num_tokens_post_padded = ( sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E)) moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
if moe_ep_size == 1:
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
else:
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_ep_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E,
start_expert, end_expert))
invoke_fused_moe_kernel(curr_hidden_states, invoke_fused_moe_kernel(curr_hidden_states,
w1, w1,
...@@ -1333,6 +1429,9 @@ def fused_moe( ...@@ -1333,6 +1429,9 @@ def fused_moe(
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = None,
start_expert: Optional[int] = None,
end_expert: Optional[int] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
...@@ -1405,4 +1504,7 @@ def fused_moe( ...@@ -1405,4 +1504,7 @@ def fused_moe(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe,
moe_ep_size=moe_ep_size,
start_expert=start_expert,
end_expert=end_expert)
...@@ -58,7 +58,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -58,7 +58,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -134,7 +134,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -134,7 +134,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1
) -> torch.Tensor: ) -> torch.Tensor:
return self.forward(x=x, return self.forward(x=x,
layer=layer, layer=layer,
...@@ -147,7 +150,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -147,7 +150,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe,
moe_ep_size=moe_ep_size,
start_expert=start_expert,
end_expert=end_expert)
def forward_cuda( def forward_cuda(
self, self,
...@@ -162,7 +168,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -162,7 +168,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -182,7 +191,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -182,7 +191,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe,
moe_ep_size=moe_ep_size,
start_expert=start_expert,
end_expert=end_expert)
def forward_cpu( def forward_cpu(
self, self,
...@@ -221,7 +233,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -221,7 +233,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None e_score_correction_bias: Optional[torch.Tensor] = None,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1
) -> torch.Tensor: ) -> torch.Tensor:
assert not use_grouped_topk assert not use_grouped_topk
assert num_expert_group is None assert num_expert_group is None
...@@ -282,6 +297,7 @@ class FusedMoE(torch.nn.Module): ...@@ -282,6 +297,7 @@ class FusedMoE(torch.nn.Module):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
moe_ep_size: Optional[int] = 1,
): ):
super().__init__() super().__init__()
...@@ -305,6 +321,17 @@ class FusedMoE(torch.nn.Module): ...@@ -305,6 +321,17 @@ class FusedMoE(torch.nn.Module):
self.scoring_func = scoring_func self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias self.e_score_correction_bias = e_score_correction_bias
self.tp_rank = get_tensor_model_parallel_rank()
self.moe_ep_size = moe_ep_size
self.moe_tp_rank = self.tp_rank // self.moe_ep_size
self.moe_tp_size = self.tp_size // self.moe_ep_size
if self.moe_ep_size > 1:
self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
self.moe_ep_rank = self.tp_rank % self.moe_ep_size
num_experts_per_node = num_experts // self.moe_ep_size
self.start_expert = num_experts_per_node * self.moe_ep_rank
self.end_expert = self.start_expert + num_experts_per_node
if self.scoring_func != "softmax" and not self.use_grouped_topk: if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for " raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.") "non-grouped topk.")
...@@ -323,7 +350,7 @@ class FusedMoE(torch.nn.Module): ...@@ -323,7 +350,7 @@ class FusedMoE(torch.nn.Module):
self.use_nn_moe = False self.use_nn_moe = False
moe_quant_params = { moe_quant_params = {
"num_experts": num_experts, "num_experts": num_experts if self.moe_ep_size == 1 else num_experts_per_node,
"hidden_size": hidden_size, "hidden_size": hidden_size,
"intermediate_size_per_partition": "intermediate_size_per_partition":
self.intermediate_size_per_partition, self.intermediate_size_per_partition,
...@@ -489,8 +516,10 @@ class FusedMoE(torch.nn.Module): ...@@ -489,8 +516,10 @@ class FusedMoE(torch.nn.Module):
# dimension intermediate_size_per_partition is used. # dimension intermediate_size_per_partition is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
expert_id = expert_id - self.start_expert
expert_data = param.data[expert_id] expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
tp_rank = tp_rank // self.moe_ep_size
# is_transposed: if the dim to shard the weight # is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors # should be flipped. Required by GPTQ, compressed-tensors
...@@ -638,7 +667,10 @@ class FusedMoE(torch.nn.Module): ...@@ -638,7 +667,10 @@ class FusedMoE(torch.nn.Module):
custom_routing_function=self.custom_routing_function, custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func, scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias, e_score_correction_bias=self.e_score_correction_bias,
use_nn_moe=self.use_nn_moe) use_nn_moe=self.use_nn_moe,
moe_ep_size=self.moe_ep_size,
start_expert=self.start_expert,
end_expert=self.end_expert)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
...@@ -663,6 +695,33 @@ class FusedMoE(torch.nn.Module): ...@@ -663,6 +695,33 @@ class FusedMoE(torch.nn.Module):
("w3", ckpt_up_proj_name), ("w3", ckpt_up_proj_name),
] ]
] ]
@classmethod
def make_expert_params_mapping_ep(
cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int,
moe_ep_size) -> List[Tuple[str, str, int, str]]:
# tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
# moe_tp_rank = tp_rank // moe_ep_size
moe_ep_rank = tp_rank % moe_ep_size
experts_per_rank = num_experts // moe_ep_size
experts_range = range(moe_ep_rank * experts_per_rank,
(moe_ep_rank + 1) * experts_per_rank)
return [
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_" if weight_name
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
for expert_id in experts_range for shard_id, weight_name in [
("w1", ckpt_gate_proj_name),
("w2", ckpt_down_proj_name),
("w3", ckpt_up_proj_name),
]
]
def _load_fp8_scale(self, param: torch.nn.Parameter, def _load_fp8_scale(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str, loaded_weight: torch.Tensor, weight_name: str,
......
...@@ -149,6 +149,8 @@ def _initialize_model( ...@@ -149,6 +149,8 @@ def _initialize_model(
kwargs["lora_config"] = vllm_config.lora_config kwargs["lora_config"] = vllm_config.lora_config
if "scheduler_config" in all_params: if "scheduler_config" in all_params:
kwargs["scheduler_config"] = vllm_config.scheduler_config kwargs["scheduler_config"] = vllm_config.scheduler_config
if "parallel_config" in all_params:
kwargs["parallel_config"] = vllm_config.parallel_config
with set_current_vllm_config(vllm_config, check_compile=True): with set_current_vllm_config(vllm_config, check_compile=True):
return model_class(**kwargs) return model_class(**kwargs)
......
...@@ -29,7 +29,7 @@ from torch import nn ...@@ -29,7 +29,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig, ParallelConfig
from vllm.distributed import (get_pp_group, from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -99,6 +99,7 @@ class DeepseekV3MoE(nn.Module): ...@@ -99,6 +99,7 @@ class DeepseekV3MoE(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
moe_ep_size: int = 1
): ):
super().__init__() super().__init__()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
...@@ -138,7 +139,8 @@ class DeepseekV3MoE(nn.Module): ...@@ -138,7 +139,8 @@ class DeepseekV3MoE(nn.Module):
topk_group=config.topk_group, topk_group=config.topk_group,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
scoring_func=config.scoring_func, scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias) e_score_correction_bias=self.gate.e_score_correction_bias,
moe_ep_size=moe_ep_size)
if config.n_shared_experts is not None: if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
...@@ -488,6 +490,7 @@ class DeepseekV3DecoderLayer(nn.Module): ...@@ -488,6 +490,7 @@ class DeepseekV3DecoderLayer(nn.Module):
model_config: ModelConfig, model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
moe_ep_size : int = 1,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -526,6 +529,7 @@ class DeepseekV3DecoderLayer(nn.Module): ...@@ -526,6 +529,7 @@ class DeepseekV3DecoderLayer(nn.Module):
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
moe_ep_size=moe_ep_size
) )
else: else:
self.mlp = DeepseekV3MLP( self.mlp = DeepseekV3MLP(
...@@ -575,7 +579,7 @@ class DeepseekV3Model(nn.Module): ...@@ -575,7 +579,7 @@ class DeepseekV3Model(nn.Module):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", moe_ep_size: int = 1):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
...@@ -602,6 +606,7 @@ class DeepseekV3Model(nn.Module): ...@@ -602,6 +606,7 @@ class DeepseekV3Model(nn.Module):
model_config=model_config, model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
moe_ep_size=moe_ep_size
), ),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
...@@ -660,8 +665,12 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP): ...@@ -660,8 +665,12 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.parallel_config = vllm_config.parallel_config
self.moe_ep_size = self.parallel_config.moe_ep_size
self.model = DeepseekV3Model(vllm_config=vllm_config, self.model = DeepseekV3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"),
moe_ep_size=self.moe_ep_size)
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
...@@ -737,11 +746,19 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP): ...@@ -737,11 +746,19 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping( if self.moe_ep_size == 1:
ckpt_gate_proj_name="gate_proj", expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_down_proj_name="down_proj", ckpt_gate_proj_name="gate_proj",
ckpt_up_proj_name="up_proj", ckpt_down_proj_name="down_proj",
num_experts=self.config.n_routed_experts) ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
else:
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
moe_ep_size=self.moe_ep_size)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
...@@ -806,6 +823,10 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP): ...@@ -806,6 +823,10 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
continue continue
# Skip loading extra expert weights for ep moe mode
if name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
......
...@@ -75,7 +75,8 @@ class MixtralMoE(nn.Module): ...@@ -75,7 +75,8 @@ class MixtralMoE(nn.Module):
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None, tp_size: Optional[int] = None,
prefix: str = ""): prefix: str = "",
moe_ep_size: int = 1):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -97,7 +98,8 @@ class MixtralMoE(nn.Module): ...@@ -97,7 +98,8 @@ class MixtralMoE(nn.Module):
renormalize=True, renormalize=True,
quant_config=quant_config, quant_config=quant_config,
tp_size=tp_size, tp_size=tp_size,
prefix=f"{prefix}.experts") prefix=f"{prefix}.experts",
moe_ep_size=moe_ep_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape. # NOTE: hidden_states can have either 1D or 2D shape.
...@@ -198,6 +200,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -198,6 +200,7 @@ class MixtralDecoderLayer(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
moe_ep_size : int = 1,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -218,7 +221,8 @@ class MixtralDecoderLayer(nn.Module): ...@@ -218,7 +221,8 @@ class MixtralDecoderLayer(nn.Module):
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.block_sparse_moe") prefix=f"{prefix}.block_sparse_moe",
moe_ep_size=moe_ep_size)
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, self.post_attention_layernorm = RMSNorm(config.hidden_size,
...@@ -256,7 +260,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -256,7 +260,7 @@ class MixtralDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class MixtralModel(nn.Module): class MixtralModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", moe_ep_size: int = 1):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
...@@ -279,7 +283,7 @@ class MixtralModel(nn.Module): ...@@ -279,7 +283,7 @@ class MixtralModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: MixtralDecoderLayer( lambda prefix: MixtralDecoderLayer(
config, cache_config, quant_config=quant_config, prefix=prefix config, cache_config, quant_config=quant_config, prefix=prefix, moe_ep_size=moe_ep_size
), ),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
...@@ -355,8 +359,12 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -355,8 +359,12 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.parallel_config = vllm_config.parallel_config
self.moe_ep_size = self.parallel_config.moe_ep_size
self.model = MixtralModel(vllm_config=vllm_config, self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"),
moe_ep_size=self.moe_ep_size)
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
...@@ -430,11 +438,19 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -430,11 +438,19 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping( if self.moe_ep_size == 1:
ckpt_gate_proj_name="w1", expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_down_proj_name="w2", ckpt_gate_proj_name="w1",
ckpt_up_proj_name="w3", ckpt_down_proj_name="w2",
num_experts=self.config.num_local_experts) ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts)
else:
expert_params_mapping = FusedMoE.make_expert_params_mapping_ep(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts,
moe_ep_size=self.moe_ep_size)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
...@@ -486,6 +502,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -486,6 +502,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if ((name.endswith(".bias") or name.endswith("_bias")) if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict): and name not in params_dict):
continue continue
# Skip loading extra expert weights for ep moe mode
if name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, weight_loader(param,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment