Commit d0dafaf4 authored by 王敏's avatar 王敏
Browse files

[feat]添加ep moe功能

parent a27fdb55
......@@ -41,7 +41,8 @@ def benchmark_config(
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100,
nn_moe: Optional[bool] = False
nn_moe: Optional[bool] = False,
moe_ep_size: int = 1,
) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
......@@ -140,6 +141,9 @@ def benchmark_config(
a1_scale=a1_scale,
a2_scale=a2_scale,
use_nn_moe=nn_moe,
moe_ep_size=moe_ep_size,
start_expert=0,
end_expert=num_experts
)
# JIT compilation & warmup
......@@ -406,7 +410,8 @@ class BenchmarkWorker:
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
search_space: List[Dict[str, int]],
nn_moe: Optional[bool] = False
nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1
) -> Dict[str, int]:
best_config = None
best_time = float("inf")
......@@ -430,7 +435,8 @@ class BenchmarkWorker:
use_fp8_w8a8,
use_int8_w8a16,
num_iters=20,
nn_moe=nn_moe)
nn_moe=nn_moe,
moe_ep_size=moe_ep_size)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
......@@ -520,29 +526,44 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
def main(args: argparse.Namespace):
print(args)
moe_ep_size = args.moe_ep_size
tp_size = args.tp_size
if moe_ep_size > 1:
tp_size = tp_size // moe_ep_size
config = AutoConfig.from_pretrained(
args.model, trust_remote_code=args.trust_remote_code)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
E = E // moe_ep_size
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
E = E // moe_ep_size
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM":
E = config.n_routed_experts
E = E // moe_ep_size
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts
E = E // moe_ep_size
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else:
# Default: Mixtral.
E = config.num_local_experts
E = E // moe_ep_size
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
shard_intermediate_size = 2 * intermediate_size // tp_size
hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
......@@ -582,7 +603,7 @@ def main(args: argparse.Namespace):
start = time.time()
configs = _distribute(
"tune", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space, args.nn_moe)
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space, args.nn_moe, moe_ep_size)
for batch_size in batch_sizes])
best_configs = {
M: sort_config(config)
......@@ -622,6 +643,7 @@ if __name__ == "__main__":
parser.add_argument("--tune", action="store_true")
parser.add_argument("--nn_moe", type=bool, default=True)
parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--moe-ep-size", type=int, default=1)
args = parser.parse_args()
main(args)
......@@ -279,6 +279,103 @@ __global__ void moe_sum_kernel(
}
}
// Expert-parallel variant of the block-size alignment kernel.
//
// Only entries of `topk_ids` whose value lies in [start_expert, end_expert)
// are counted and sorted; every other entry is skipped. Matching ids are
// rebased by `start_expert`, so all per-expert state (cumsum, tokens_cnts,
// expert_ids) is indexed by the rebased id in [0, end_expert - start_expert).
// `num_experts` is the row width of that state and must therefore be at least
// end_expert - start_expert.
//
// All indexing uses threadIdx.x / blockDim.x only (no blockIdx), and the
// dynamic shared memory must hold (num_experts + 1) int32_t values followed by
// (blockDim.x + 1) * num_experts token_cnts_t values.
//
// Outputs:
//   sorted_token_ids      - flat indices into topk_ids, grouped by rebased
//                           expert id; slots not written here keep whatever
//                           value the caller preset.
//   expert_ids            - rebased expert id for each block of `block_size`
//                           slots.
//   total_tokens_post_pad - total number of slots after padding each expert's
//                           count up to a multiple of `block_size`.
template <typename scalar_t, typename token_cnts_t>
__global__ void ep_moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                                               int32_t* sorted_token_ids,
                                               int32_t* expert_ids,
                                               int32_t* total_tokens_post_pad,
                                               int32_t num_experts,
                                               int32_t block_size, size_t numel,
                                               int32_t start_expert, int32_t end_expert) {
  // Each thread owns one contiguous shard of topk_ids.
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  extern __shared__ int32_t shared_mem[];

  int32_t* cumsum = shared_mem;  // 1d tensor with shape (num_experts + 1)
  token_cnts_t* tokens_cnts =
      (token_cnts_t*)(shared_mem + num_experts +
                      1);  // 2d tensor with shape (blockDim.x + 1, num_experts)

  // Each thread clears its own row (row threadIdx.x + 1); row 0 is cleared
  // later, per expert column, before the prefix sum.
  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  }

  /**
   * In the first step we compute token_cnts[thread_index + 1][expert_index],
   * which counts how many tokens in the token shard of thread_index are
   * assigned to expert expert_index. Only ids inside
   * [start_expert, end_expert) are counted, using the rebased index
   * topk_ids[i] - start_expert.
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    if (topk_ids[i] >= start_expert && topk_ids[i] < end_expert) {
      ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i] - start_expert)];
    }
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  // After this loop, row i holds the number of matching tokens seen by threads
  // 0..i-1, and row blockDim.x holds the total for the expert.
  if (threadIdx.x < num_experts) {
    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
    }
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  // Each expert's total is rounded up to a multiple of block_size, so
  // cumsum[e] is the padded start offset of (rebased) expert e.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                          block_size) *
                      block_size;
    }
    *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block. The stored
   * id is the rebased (start_expert-relative) index, not the original id
   * found in topk_ids.
   */
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
  }

  /**
   * Each thread processes a token shard, calculating the index of each token
   * after sorting by expert number. Given the example topk_ids =
   * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
   * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
   * padding value(preset in python). That example corresponds to
   * start_expert = 0 with every id inside the window; ids outside
   * [start_expert, end_expert) produce no entry at all.
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    if (expert_id >= start_expert && expert_id < end_expert) {
      expert_id -= start_expert;
      /** The cumsum[expert_id] stores the starting index of the tokens that the
       * expert with expert_id needs to process, and
       * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
       * processed by the expert with expert_id within the current thread's token
       * shard.
       */
      int32_t rank_post_pad =
          tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
          cumsum[expert_id];
      sorted_token_ids[rank_post_pad] = i;
      // Advance this thread's write cursor for the expert.
      ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
    }
  }
}
} // namespace moe
} // namespace vllm
......@@ -371,6 +468,109 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
}
}
// Host launcher for vllm::moe::ep_moe_align_block_size_kernel.
//
// Sorts the entries of `topk_ids` that fall inside [start_expert, end_expert)
// by expert and pads each expert's share to a multiple of `block_size`.
//
//   topk_ids            - integral tensor of expert ids.
//   num_experts         - row width of the kernel's per-expert shared-memory
//                         state; must cover end_expert - start_expert.
//   block_size          - padding granularity.
//   sorted_token_ids    - int32 output, written in place.
//   experts_ids         - int32 output, written in place.
//   num_tokens_post_pad - int32 output (single element), written in place.
//   start_expert / end_expert - half-open window of expert ids to keep.
//
// Only the int32 shared-memory path is implemented: the kernel is launched
// with one block and `shared_mem_i32` bytes of dynamic shared memory. There is
// no uint16 or global-memory fallback, so an oversized request is rejected up
// front instead of failing inside the launch.
void ep_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                             int64_t block_size, torch::Tensor sorted_token_ids,
                             torch::Tensor experts_ids,
                             torch::Tensor num_tokens_post_pad,
                             int64_t start_expert, int64_t end_expert) {
  // The kernel indexes its shared-memory tables with (id - start_expert), so
  // a window wider than num_experts would write out of bounds.
  TORCH_CHECK(end_expert - start_expert <= num_experts,
              "ep_moe_align_block_size: expert window [", start_expert, ", ",
              end_expert, ") is wider than num_experts=", num_experts);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int device_max_shared_mem;
  auto dev = topk_ids.get_device();
  AT_CUDA_CHECK(cudaDeviceGetAttribute(
      &device_max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev));

  const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);

  // cumsum: (num_experts + 1) int32 values.
  // tokens_cnts: (num_thread + 1) x num_experts int32 values.
  // Computed in 64 bits so the limit check below cannot be defeated by
  // overflow before narrowing.
  const int64_t shared_mem_bytes =
      ((num_thread + 1) * num_experts + (num_experts + 1)) *
      static_cast<int64_t>(sizeof(int32_t));
  TORCH_CHECK(shared_mem_bytes <= device_max_shared_mem,
              "ep_moe_align_block_size: requires ", shared_mem_bytes,
              " bytes of shared memory, but the device allows at most ",
              device_max_shared_mem, " per block");
  const int32_t shared_mem_i32 = static_cast<int32_t>(shared_mem_bytes);

  // The kernel takes 32-bit expert bounds; narrow explicitly.
  const int32_t start_expert_i32 = static_cast<int32_t>(start_expert);
  const int32_t end_expert_i32 = static_cast<int32_t>(end_expert);

  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "ep_moe_align_block_size_kernel", [&] {
        auto kernel =
            vllm::moe::ep_moe_align_block_size_kernel<scalar_t, int32_t>;
        // Opt in to the larger dynamic shared-memory size before launching.
        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
            (void*)kernel, shared_mem_i32));
        kernel<<<1, num_thread, shared_mem_i32, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
            topk_ids.numel(), start_expert_i32, end_expert_i32);
      });
}
void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size,
torch::Tensor sorted_token_ids,
......
......@@ -13,6 +13,12 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
void ep_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad,
int64_t start_expert, int64_t end_expert);
void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size,
torch::Tensor sorted_token_ids,
......
......@@ -22,6 +22,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
ops.def(
"ep_moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad,"
" int start_expert, int end_expert) -> ()");
ops.impl("ep_moe_align_block_size", torch::kCUDA, &ep_moe_align_block_size);
// temporarily adapted from
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
m.def(
......
......@@ -1379,6 +1379,16 @@ def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size, sorted_token_ids,
experts_ids, num_tokens_post_pad)
def ep_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                            block_size: int, sorted_token_ids: torch.Tensor,
                            experts_ids: torch.Tensor,
                            num_tokens_post_pad: torch.Tensor,
                            start_expert: int, end_expert: int) -> None:
    """Forward to the ``_C.ep_moe_align_block_size`` custom op.

    The op restricts block-size alignment to expert ids in
    ``[start_expert, end_expert)``. ``sorted_token_ids``, ``experts_ids`` and
    ``num_tokens_post_pad`` are output tensors that the op mutates in place;
    nothing is returned.
    """
    torch.ops._C.ep_moe_align_block_size(topk_ids, num_experts, block_size,
                                         sorted_token_ids, experts_ids,
                                         num_tokens_post_pad, start_expert,
                                         end_expert)
def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies: torch.Tensor,
......
......@@ -1342,6 +1342,8 @@ class ParallelConfig:
rank: int = 0
moe_ep_size: Optional[int] = 1
def compute_hash(self):
"""
Provide a hash that uniquely identifies all the configs
......
......@@ -206,6 +206,8 @@ class EngineArgs:
calculate_kv_scales: Optional[bool] = None
moe_ep_size: int = 1
def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model
......@@ -417,6 +419,10 @@ class EngineArgs:
type=int,
default=EngineArgs.tensor_parallel_size,
help='Number of tensor parallel replicas.')
parser.add_argument('--moe-ep-size',
type=int,
default=EngineArgs.moe_ep_size,
help='Number of moe expert parallel replicas.')
parser.add_argument(
'--max-parallel-loading-workers',
type=int,
......@@ -1123,6 +1129,7 @@ class EngineArgs:
ray_workers_use_nsight=self.ray_workers_use_nsight,
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,
moe_ep_size=self.moe_ep_size,
)
max_model_len = model_config.max_model_len
......
......@@ -633,6 +633,66 @@ def moe_align_block_size(
return sorted_ids, expert_ids, num_tokens_post_pad
def moe_ep_align_block_size(
        topk_ids: torch.Tensor, block_size: int,
        num_experts: int, start_expert: int,
        end_expert: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Expert-parallel variant of ``moe_align_block_size``.

    Allocates the three output buffers and hands them to the
    ``ep_moe_align_block_size`` op, which groups the entries of ``topk_ids``
    by expert and pads every expert's share to a multiple of ``block_size``,
    considering only expert ids in ``[start_expert, end_expert)``.

    Parameters:
    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
        top-k expert indices for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The number of experts used to size the padded buffers and
        passed to the op. NOTE(review): callers in this file pass the
        per-rank expert count here - confirm against the caller.
    - start_expert: First expert id (inclusive) handled by this rank.
    - end_expert: Last expert id (exclusive) handled by this rank.

    Returns:
    - sorted_token_ids: Flat token indices grouped by expert. Every slot is
        pre-filled with ``topk_ids.numel()``, which acts as the padding
        value for positions the op does not write.
    - expert_ids: The expert index assigned to each block. The buffer is
        uninitialised beyond the blocks the op writes.
    - num_tokens_post_padded: One-element tensor holding the total number of
        slots after padding, a multiple of ``block_size``.

    Example (all experts local, i.e. the window covers every id):
    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
    block_size = 4, and num_experts = 4:
    - Flattening gives [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]: 12 slots, three
        per expert, so one padding slot (value 12) is added per expert.
    - After sorting by expert index, we obtain token_ids
        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
        Tokens 12 are non-existent (padding) and are ignored in
        the subsequent matrix multiplication.
    """
    device = topk_ids.device
    num_slots = topk_ids.numel()

    # Worst case: every expert needs (block_size - 1) padding slots.
    padded_capacity = num_slots + num_experts * (block_size - 1)

    # `num_slots` is one past the last valid flat index, so it marks padding.
    sorted_ids = torch.full((padded_capacity, ),
                            num_slots,
                            dtype=torch.int32,
                            device=device)
    expert_ids = torch.empty((triton.cdiv(padded_capacity, block_size), ),
                             dtype=torch.int32,
                             device=device)
    num_tokens_post_pad = torch.empty((1, ),
                                      dtype=torch.int32,
                                      device=device)

    ops.ep_moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                                expert_ids, num_tokens_post_pad, start_expert,
                                end_expert)
    return sorted_ids, expert_ids, num_tokens_post_pad
def invoke_fused_moe_kernel(A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
......@@ -1029,11 +1089,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None:
use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape,
use_nn_moe)
use_nn_moe, moe_ep_size=moe_ep_size,
start_expert=start_expert, end_expert=end_expert)
def inplace_fused_experts_fake(
......@@ -1052,7 +1116,10 @@ def inplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None:
use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1) -> None:
pass
......@@ -1080,12 +1147,16 @@ def outplace_fused_experts(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape,
use_nn_moe)
use_nn_moe, moe_ep_size=moe_ep_size,
start_expert=start_expert, end_expert=end_expert)
def outplace_fused_experts_fake(
......@@ -1104,7 +1175,10 @@ def outplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -1132,7 +1206,10 @@ def fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False):
use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1):
if inplace:
torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
topk_weights, topk_ids,
......@@ -1140,14 +1217,19 @@ def fused_experts(hidden_states: torch.Tensor,
use_int4_w4a16, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape,
use_nn_moe)
use_nn_moe,
moe_ep_size=moe_ep_size,
start_expert=start_expert,
end_expert=end_expert)
return hidden_states
else:
return torch.ops.vllm.outplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape,
use_nn_moe)
use_nn_moe, moe_ep_size=moe_ep_size,
start_expert=start_expert,
end_expert=end_expert)
def fused_experts_impl(hidden_states: torch.Tensor,
......@@ -1166,7 +1248,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False):
use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1):
# Check constraints.
if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[
......@@ -1220,6 +1305,9 @@ def fused_experts_impl(hidden_states: torch.Tensor,
device=hidden_states.device,
dtype=hidden_states.dtype)
if moe_ep_size > 1:
intermediate_cache3.zero_()
if hidden_states.dtype == torch.bfloat16:
compute_type = tl.bfloat16
elif hidden_states.dtype == torch.float16:
......@@ -1260,6 +1348,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
if moe_ep_size == 1:
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
else:
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_ep_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E,
start_expert, end_expert))
invoke_fused_moe_kernel(curr_hidden_states,
w1,
intermediate_cache1,
......@@ -1333,6 +1429,9 @@ def fused_moe(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = None,
start_expert: Optional[int] = None,
end_expert: Optional[int] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
......@@ -1405,4 +1504,7 @@ def fused_moe(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe=use_nn_moe)
use_nn_moe=use_nn_moe,
moe_ep_size=moe_ep_size,
start_expert=start_expert,
end_expert=end_expert)
......@@ -58,7 +58,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
......@@ -134,7 +134,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False
use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1
) -> torch.Tensor:
return self.forward(x=x,
layer=layer,
......@@ -147,7 +150,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
use_nn_moe=use_nn_moe)
use_nn_moe=use_nn_moe,
moe_ep_size=moe_ep_size,
start_expert=start_expert,
end_expert=end_expert)
def forward_cuda(
self,
......@@ -162,7 +168,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False
use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1
) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
......@@ -182,7 +191,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_nn_moe=use_nn_moe)
use_nn_moe=use_nn_moe,
moe_ep_size=moe_ep_size,
start_expert=start_expert,
end_expert=end_expert)
def forward_cpu(
self,
......@@ -221,7 +233,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
e_score_correction_bias: Optional[torch.Tensor] = None,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1
) -> torch.Tensor:
assert not use_grouped_topk
assert num_expert_group is None
......@@ -282,6 +297,7 @@ class FusedMoE(torch.nn.Module):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
moe_ep_size: Optional[int] = 1,
):
super().__init__()
......@@ -305,6 +321,17 @@ class FusedMoE(torch.nn.Module):
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.tp_rank = get_tensor_model_parallel_rank()
self.moe_ep_size = moe_ep_size
# Split the tensor-parallel rank into an (expert-parallel, MoE
# tensor-parallel) pair: ranks sharing `tp_rank // moe_ep_size` form one
# expert-parallel group.
self.moe_tp_rank = self.tp_rank // self.moe_ep_size
self.moe_tp_size = self.tp_size // self.moe_ep_size
self.moe_ep_rank = self.tp_rank % self.moe_ep_size

# Always define the local expert window. weight_loader and forward read
# start_expert / end_expert unconditionally, so leaving them unset when
# moe_ep_size == 1 would raise AttributeError. With moe_ep_size == 1 the
# window is simply [0, num_experts).
num_experts_per_node = num_experts // self.moe_ep_size
self.start_expert = num_experts_per_node * self.moe_ep_rank
self.end_expert = self.start_expert + num_experts_per_node

if self.moe_ep_size > 1:
    # Under expert parallelism the intermediate dimension is sharded only
    # across the MoE tensor-parallel group.
    self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
......@@ -323,7 +350,7 @@ class FusedMoE(torch.nn.Module):
self.use_nn_moe = False
moe_quant_params = {
"num_experts": num_experts,
"num_experts": num_experts if self.moe_ep_size == 1 else num_experts_per_node,
"hidden_size": hidden_size,
"intermediate_size_per_partition":
self.intermediate_size_per_partition,
......@@ -489,8 +516,10 @@ class FusedMoE(torch.nn.Module):
# dimension intermediate_size_per_partition is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
expert_id = expert_id - self.start_expert
expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank()
tp_rank = tp_rank // self.moe_ep_size
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
......@@ -638,7 +667,10 @@ class FusedMoE(torch.nn.Module):
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
use_nn_moe=self.use_nn_moe)
use_nn_moe=self.use_nn_moe,
moe_ep_size=self.moe_ep_size,
start_expert=self.start_expert,
end_expert=self.end_expert)
if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
......@@ -664,6 +696,33 @@ class FusedMoE(torch.nn.Module):
]
]
@classmethod
def make_expert_params_mapping_ep(
        cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int,
        moe_ep_size) -> List[Tuple[str, str, int, str]]:
    """Build the checkpoint-to-parameter mapping for this rank's experts.

    The expert-parallel rank is ``tensor_parallel_rank % moe_ep_size`` and
    each rank covers ``num_experts // moe_ep_size`` consecutive expert ids.
    Only those ids appear in the result.

    Returns a list of ``(param_name, weight_name, expert_id, shard_id)``
    tuples, three per expert (shards "w1", "w2", "w3"). ``expert_id`` is the
    checkpoint's expert id, not an index relative to this rank.
    """
    ep_rank = get_tensor_model_parallel_rank() % moe_ep_size
    local_count = num_experts // moe_ep_size
    first_expert = ep_rank * local_count

    # (shard_id, checkpoint weight name) in the order they are emitted.
    shard_sources = (
        ("w1", ckpt_gate_proj_name),
        ("w2", ckpt_down_proj_name),
        ("w3", ckpt_up_proj_name),
    )
    # Gate and up projections load into the merged w13 parameter.
    merged_sources = [ckpt_gate_proj_name, ckpt_up_proj_name]

    mapping = []
    for expert_id in range(first_expert, first_expert + local_count):
        for shard_id, weight_name in shard_sources:
            if weight_name in merged_sources:
                param_name = "experts.w13_"
            else:
                param_name = "experts.w2_"
            mapping.append(
                (param_name, f"experts.{expert_id}.{weight_name}.",
                 expert_id, shard_id))
    return mapping
def _load_fp8_scale(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
......
......@@ -149,6 +149,8 @@ def _initialize_model(
kwargs["lora_config"] = vllm_config.lora_config
if "scheduler_config" in all_params:
kwargs["scheduler_config"] = vllm_config.scheduler_config
if "parallel_config" in all_params:
kwargs["parallel_config"] = vllm_config.parallel_config
with set_current_vllm_config(vllm_config, check_compile=True):
return model_class(**kwargs)
......
......@@ -29,7 +29,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig, ParallelConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
......@@ -99,6 +99,7 @@ class DeepseekV3MoE(nn.Module):
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
moe_ep_size: int = 1
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
......@@ -138,7 +139,8 @@ class DeepseekV3MoE(nn.Module):
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias)
e_score_correction_bias=self.gate.e_score_correction_bias,
moe_ep_size=moe_ep_size)
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
......@@ -488,6 +490,7 @@ class DeepseekV3DecoderLayer(nn.Module):
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
moe_ep_size : int = 1,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -526,6 +529,7 @@ class DeepseekV3DecoderLayer(nn.Module):
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
moe_ep_size=moe_ep_size
)
else:
self.mlp = DeepseekV3MLP(
......@@ -575,7 +579,7 @@ class DeepseekV3Model(nn.Module):
fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", moe_ep_size: int = 1):
super().__init__()
config = vllm_config.model_config.hf_config
......@@ -602,6 +606,7 @@ class DeepseekV3Model(nn.Module):
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
moe_ep_size=moe_ep_size
),
prefix=f"{prefix}.layers")
......@@ -660,8 +665,12 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.parallel_config = vllm_config.parallel_config
self.moe_ep_size = self.parallel_config.moe_ep_size
self.model = DeepseekV3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
prefix=maybe_prefix(prefix, "model"),
moe_ep_size=self.moe_ep_size)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
......@@ -737,11 +746,19 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
if self.moe_ep_size == 1:
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=self.config.n_routed_experts)
else:
    # Expert parallelism needs the *_ep variant: it is the one that accepts
    # moe_ep_size and restricts the mapping to the experts this rank owns.
    expert_params_mapping = FusedMoE.make_expert_params_mapping_ep(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=self.config.n_routed_experts,
        moe_ep_size=self.moe_ep_size)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
......@@ -807,6 +824,10 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
if is_pp_missing_parameter(name, self):
continue
# Skip loading extra expert weights for ep moe mode
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
......@@ -75,7 +75,8 @@ class MixtralMoE(nn.Module):
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = ""):
prefix: str = "",
moe_ep_size: int = 1):
super().__init__()
self.hidden_size = hidden_size
......@@ -97,7 +98,8 @@ class MixtralMoE(nn.Module):
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
prefix=f"{prefix}.experts")
prefix=f"{prefix}.experts",
moe_ep_size=moe_ep_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
......@@ -198,6 +200,7 @@ class MixtralDecoderLayer(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
moe_ep_size : int = 1,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -218,7 +221,8 @@ class MixtralDecoderLayer(nn.Module):
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.block_sparse_moe")
prefix=f"{prefix}.block_sparse_moe",
moe_ep_size=moe_ep_size)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
......@@ -256,7 +260,7 @@ class MixtralDecoderLayer(nn.Module):
@support_torch_compile
class MixtralModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", moe_ep_size: int = 1):
super().__init__()
config = vllm_config.model_config.hf_config
......@@ -279,7 +283,7 @@ class MixtralModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MixtralDecoderLayer(
config, cache_config, quant_config=quant_config, prefix=prefix
config, cache_config, quant_config=quant_config, prefix=prefix, moe_ep_size=moe_ep_size
),
prefix=f"{prefix}.layers")
......@@ -355,8 +359,12 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config
self.quant_config = quant_config
self.parallel_config = vllm_config.parallel_config
self.moe_ep_size = self.parallel_config.moe_ep_size
self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
prefix=maybe_prefix(prefix, "model"),
moe_ep_size=self.moe_ep_size)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......@@ -430,11 +438,19 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
if self.moe_ep_size == 1:
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts)
else:
expert_params_mapping = FusedMoE.make_expert_params_mapping_ep(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts,
moe_ep_size=self.moe_ep_size)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
......@@ -486,6 +502,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip loading extra expert weights for ep moe mode
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment