Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d0dafaf4
Commit
d0dafaf4
authored
Feb 11, 2025
by
王敏
Browse files
[feat]添加ep moe功能
parent
a27fdb55
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
506 additions
and
47 deletions
+506
-47
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+30
-8
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+200
-0
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+6
-0
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+8
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+10
-0
vllm/config.py
vllm/config.py
+2
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+7
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+113
-11
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+67
-8
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+2
-0
vllm/model_executor/models/deepseek_v3.py
vllm/model_executor/models/deepseek_v3.py
+30
-9
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+31
-11
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
d0dafaf4
...
@@ -41,7 +41,8 @@ def benchmark_config(
...
@@ -41,7 +41,8 @@ def benchmark_config(
use_fp8_w8a8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
num_iters
:
int
=
100
,
num_iters
:
int
=
100
,
nn_moe
:
Optional
[
bool
]
=
False
nn_moe
:
Optional
[
bool
]
=
False
,
moe_ep_size
:
int
=
1
,
)
->
float
:
)
->
float
:
init_dtype
=
torch
.
float16
if
use_fp8_w8a8
else
dtype
init_dtype
=
torch
.
float16
if
use_fp8_w8a8
else
dtype
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
...
@@ -140,6 +141,9 @@ def benchmark_config(
...
@@ -140,6 +141,9 @@ def benchmark_config(
a1_scale
=
a1_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
a2_scale
=
a2_scale
,
use_nn_moe
=
nn_moe
,
use_nn_moe
=
nn_moe
,
moe_ep_size
=
moe_ep_size
,
start_expert
=
0
,
end_expert
=
num_experts
)
)
# JIT compilation & warmup
# JIT compilation & warmup
...
@@ -406,7 +410,8 @@ class BenchmarkWorker:
...
@@ -406,7 +410,8 @@ class BenchmarkWorker:
use_fp8_w8a8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
search_space
:
List
[
Dict
[
str
,
int
]],
search_space
:
List
[
Dict
[
str
,
int
]],
nn_moe
:
Optional
[
bool
]
=
False
nn_moe
:
Optional
[
bool
]
=
False
,
moe_ep_size
:
Optional
[
int
]
=
1
)
->
Dict
[
str
,
int
]:
)
->
Dict
[
str
,
int
]:
best_config
=
None
best_config
=
None
best_time
=
float
(
"inf"
)
best_time
=
float
(
"inf"
)
...
@@ -430,7 +435,8 @@ class BenchmarkWorker:
...
@@ -430,7 +435,8 @@ class BenchmarkWorker:
use_fp8_w8a8
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int8_w8a16
,
num_iters
=
20
,
num_iters
=
20
,
nn_moe
=
nn_moe
)
nn_moe
=
nn_moe
,
moe_ep_size
=
moe_ep_size
)
except
triton
.
runtime
.
autotuner
.
OutOfResources
:
except
triton
.
runtime
.
autotuner
.
OutOfResources
:
# Some configurations may be invalid and fail to compile.
# Some configurations may be invalid and fail to compile.
continue
continue
...
@@ -520,29 +526,44 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
...
@@ -520,29 +526,44 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
def
main
(
args
:
argparse
.
Namespace
):
def
main
(
args
:
argparse
.
Namespace
):
print
(
args
)
print
(
args
)
moe_ep_size
=
args
.
moe_ep_size
tp_size
=
args
.
tp_size
if
moe_ep_size
>
1
:
tp_size
=
tp_size
//
moe_ep_size
config
=
AutoConfig
.
from_pretrained
(
config
=
AutoConfig
.
from_pretrained
(
args
.
model
,
trust_remote_code
=
args
.
trust_remote_code
)
args
.
model
,
trust_remote_code
=
args
.
trust_remote_code
)
if
config
.
architectures
[
0
]
==
"DbrxForCausalLM"
:
if
config
.
architectures
[
0
]
==
"DbrxForCausalLM"
:
E
=
config
.
ffn_config
.
moe_num_experts
E
=
config
.
ffn_config
.
moe_num_experts
E
=
E
//
moe_ep_size
topk
=
config
.
ffn_config
.
moe_top_k
topk
=
config
.
ffn_config
.
moe_top_k
intermediate_size
=
config
.
ffn_config
.
ffn_hidden_size
intermediate_size
=
config
.
ffn_config
.
ffn_hidden_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
elif
config
.
architectures
[
0
]
==
"JambaForCausalLM"
:
elif
config
.
architectures
[
0
]
==
"JambaForCausalLM"
:
E
=
config
.
num_experts
E
=
config
.
num_experts
E
=
E
//
moe_ep_size
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
intermediate_size
=
config
.
intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
elif
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
or
"DeepseekV3ForCausalLM"
:
elif
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
or
"DeepseekV3ForCausalLM"
:
E
=
config
.
n_routed_experts
E
=
config
.
n_routed_experts
E
=
E
//
moe_ep_size
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
elif
config
.
architectures
[
0
]
==
"Qwen2MoeForCausalLM"
:
E
=
config
.
num_experts
E
=
E
//
moe_ep_size
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
intermediate_size
=
config
.
moe_intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
else
:
else
:
# Default: Mixtral.
# Default: Mixtral.
E
=
config
.
num_local_experts
E
=
config
.
num_local_experts
E
=
E
//
moe_ep_size
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
intermediate_size
=
config
.
intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
torch_dtype
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
torch_dtype
...
@@ -582,7 +603,7 @@ def main(args: argparse.Namespace):
...
@@ -582,7 +603,7 @@ def main(args: argparse.Namespace):
start
=
time
.
time
()
start
=
time
.
time
()
configs
=
_distribute
(
configs
=
_distribute
(
"tune"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
"tune"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
search_space
,
args
.
nn_moe
)
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
search_space
,
args
.
nn_moe
,
moe_ep_size
)
for
batch_size
in
batch_sizes
])
for
batch_size
in
batch_sizes
])
best_configs
=
{
best_configs
=
{
M
:
sort_config
(
config
)
M
:
sort_config
(
config
)
...
@@ -622,6 +643,7 @@ if __name__ == "__main__":
...
@@ -622,6 +643,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--tune"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--tune"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--nn_moe"
,
type
=
bool
,
default
=
True
)
parser
.
add_argument
(
"--nn_moe"
,
type
=
bool
,
default
=
True
)
parser
.
add_argument
(
"--trust-remote-code"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--trust-remote-code"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--moe-ep-size"
,
type
=
int
,
default
=
1
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
)
main
(
args
)
csrc/moe/moe_align_sum_kernels.cu
View file @
d0dafaf4
...
@@ -279,6 +279,103 @@ __global__ void moe_sum_kernel(
...
@@ -279,6 +279,103 @@ __global__ void moe_sum_kernel(
}
}
}
}
/**
 * Expert-parallel variant of the block-size alignment kernel.
 *
 * Only entries of topk_ids whose value lies in [start_expert, end_expert) are
 * counted and placed; every other entry is skipped, so the corresponding
 * slots of sorted_token_ids are left untouched by this kernel. Expert ids are
 * rebased to a local index (topk_ids[i] - start_expert) before they are used
 * to address tokens_cnts / cumsum, and expert_ids receives those local
 * indices.
 *
 * The kernel only uses threadIdx.x / blockDim.x (never blockIdx), i.e. it is
 * written for a single-block launch. It expects dynamic shared memory holding
 * (num_experts + 1) int32_t values for cumsum followed by
 * (blockDim.x + 1) * num_experts token_cnts_t counters.
 *
 * NOTE(review): the local index is not range-checked against num_experts;
 * presumably callers guarantee end_expert - start_expert <= num_experts --
 * confirm at the call site.
 *
 * @tparam scalar_t      element type of topk_ids.
 * @tparam token_cnts_t  integer type used for the per-thread counters.
 * @param topk_ids               flat array of numel expert ids (read only).
 * @param sorted_token_ids       output: flat token indices grouped by local
 *                               expert, each group padded to block_size.
 * @param expert_ids             output: local expert index of each block.
 * @param total_tokens_post_pad  output: padded token count (single value).
 * @param num_experts            number of columns of the counter table.
 * @param block_size             padding granularity per expert.
 * @param numel                  number of entries in topk_ids.
 * @param start_expert           first expert id handled (inclusive).
 * @param end_expert             last expert id handled (exclusive).
 */
template <typename scalar_t, typename token_cnts_t>
__global__ void ep_moe_align_block_size_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
    int32_t block_size, size_t numel, int32_t start_expert,
    int32_t end_expert) {
  // Each thread owns a contiguous shard of tokens_per_thread entries.
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  extern __shared__ int32_t shared_mem[];

  int32_t* cumsum = shared_mem;  // 1d tensor with shape (num_experts + 1)
  token_cnts_t* tokens_cnts =
      (token_cnts_t*)(shared_mem + num_experts +
                      1);  // 2d tensor with shape (blockDim.x + 1, num_experts)

  // Clear this thread's own row (row threadIdx.x + 1); row 0 is cleared later
  // by the threads that own an expert column.
  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  }

  /**
   * In the first step we compute token_cnts[thread_index + 1][expert_index],
   * which counts how many tokens in the token shard of thread_index are
   * assigned to expert expert_index. Tokens routed to experts outside
   * [start_expert, end_expert) are ignored.
   */
  // NOTE(review): `i` is an int initialised from a size_t; presumably numel
  // always fits in int -- confirm.
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    if (topk_ids[i] >= start_expert && topk_ids[i] < end_expert) {
      ++tokens_cnts[index(num_experts, threadIdx.x + 1,
                          topk_ids[i] - start_expert)];
    }
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  // After this loop row r holds the number of tokens of that expert seen by
  // threads 0..r-1, and row blockDim.x holds the expert's total.
  if (threadIdx.x < num_experts) {
    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
    }
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  // Each expert's total is rounded up to a multiple of block_size first, so
  // cumsum[e] is the padded start offset of local expert e.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                          block_size) *
                      block_size;
    }
    *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block. The value
   * stored is the local expert index (threadIdx.x), not the global expert id.
   */
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
  }

  /**
   * Each thread processes a token shard, calculating the index of each token
   * after sorting by expert number. Given the example topk_ids =
   * [0,1,2,1,2,3,0,3,4] and block_size = 4 (with every expert id inside
   * [start_expert, end_expert) and start_expert = 0), the output would be
   * [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where *
   * represents a padding value (preset in python).
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    if (expert_id >= start_expert && expert_id < end_expert) {
      // Rebase the global expert id to the local column index.
      expert_id -= start_expert;
      /** The cumsum[expert_id] stores the starting index of the tokens that the
       * expert with expert_id needs to process, and
       * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
       * processed by the expert with expert_id within the current thread's
       * token shard.
       */
      int32_t rank_post_pad =
          tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
          cumsum[expert_id];
      sorted_token_ids[rank_post_pad] = i;
      // Advance this thread's running offset for the expert.
      ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
    }
  }
}
}
// namespace moe
}
// namespace moe
}
// namespace vllm
}
// namespace vllm
...
@@ -371,6 +468,109 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -371,6 +468,109 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
}
}
}
}
/**
 * Host launcher for vllm::moe::ep_moe_align_block_size_kernel.
 *
 * Sorts the flat token indices of topk_ids by expert, restricted to the
 * experts in [start_expert, end_expert), and pads every expert's group to a
 * multiple of block_size. Results are written into the caller-provided
 * tensors sorted_token_ids, experts_ids and num_tokens_post_pad (all accessed
 * as int32).
 *
 * The kernel is launched as a single block of max(num_experts, WARP_SIZE)
 * threads on the current CUDA stream, with int32 token counters held in
 * dynamic shared memory.
 *
 * NOTE: unlike the commented-out code this function previously carried (a
 * copy of the uint16-counter and global-memory fallbacks), this path has no
 * fallback. If the requested shared memory is larger than the device allows,
 * the attribute call below is expected to report the error through
 * AT_CUDA_CHECK -- TODO confirm on the target platform.
 *
 * @param topk_ids             expert id of every (token, k) pair.
 * @param num_experts          number of experts handled by this rank; sizes
 *                             the kernel's counter table.
 * @param block_size           padding granularity, must be positive.
 * @param sorted_token_ids     output tensor (int32).
 * @param experts_ids          output tensor (int32), local expert per block.
 * @param num_tokens_post_pad  output tensor (int32), padded token count.
 * @param start_expert         first expert id handled (inclusive).
 * @param end_expert           last expert id handled (exclusive).
 */
void ep_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                             int64_t block_size,
                             torch::Tensor sorted_token_ids,
                             torch::Tensor experts_ids,
                             torch::Tensor num_tokens_post_pad,
                             int64_t start_expert, int64_t end_expert) {
  // The kernel divides by block_size when padding each expert's token count.
  TORCH_CHECK(block_size > 0,
              "ep_moe_align_block_size: block_size must be positive, got ",
              block_size);
  // The kernel indexes its per-expert tables with (expert_id - start_expert);
  // a window wider than num_experts would write outside shared memory.
  TORCH_CHECK(end_expert - start_expert <= num_experts,
              "ep_moe_align_block_size: expert range [", start_expert, ", ",
              end_expert, ") is wider than num_experts=", num_experts);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
  // (num_thread + 1) x num_experts counters plus (num_experts + 1) cumsum
  // entries, all stored as int32_t.
  const int32_t shared_mem_i32 =
      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);

  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "ep_moe_align_block_size_kernel", [&] {
        auto kernel =
            vllm::moe::ep_moe_align_block_size_kernel<scalar_t, int32_t>;
        // Opt in to the required amount of dynamic shared memory.
        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
            (void*)kernel, shared_mem_i32));
        kernel<<<1, num_thread, shared_mem_i32, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
            topk_ids.numel(), start_expert, end_expert);
      });
}
void
sgl_moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
void
sgl_moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
int64_t
block_size
,
int64_t
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
sorted_token_ids
,
...
...
csrc/moe/moe_ops.h
View file @
d0dafaf4
...
@@ -13,6 +13,12 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -13,6 +13,12 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
torch
::
Tensor
num_tokens_post_pad
);
// Expert-parallel counterpart of moe_align_block_size. Takes the same
// arguments plus an expert window: only entries of topk_ids in
// [start_expert, end_expert) are sorted and padded; results are written into
// sorted_token_ids, experts_ids and num_tokens_post_pad.
void ep_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                             int64_t block_size,
                             torch::Tensor sorted_token_ids,
                             torch::Tensor experts_ids,
                             torch::Tensor num_tokens_post_pad,
                             int64_t start_expert, int64_t end_expert);
void
sgl_moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
void
sgl_moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
int64_t
block_size
,
int64_t
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
sorted_token_ids
,
...
...
csrc/moe/torch_bindings.cpp
View file @
d0dafaf4
...
@@ -22,6 +22,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
...
@@ -22,6 +22,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()"
);
" Tensor! num_tokens_post_pad) -> ()"
);
m
.
impl
(
"moe_align_block_size"
,
torch
::
kCUDA
,
&
moe_align_block_size
);
m
.
impl
(
"moe_align_block_size"
,
torch
::
kCUDA
,
&
moe_align_block_size
);
ops
.
def
(
"ep_moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad,"
" int start_expert, int end_expert) -> ()"
);
ops
.
impl
(
"ep_moe_align_block_size"
,
torch
::
kCUDA
,
&
ep_moe_align_block_size
);
// temporarily adapted from
// temporarily adapted from
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
m
.
def
(
m
.
def
(
...
...
vllm/_custom_ops.py
View file @
d0dafaf4
...
@@ -1378,6 +1378,16 @@ def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
...
@@ -1378,6 +1378,16 @@ def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
torch
.
ops
.
_moe_C
.
sgl_moe_align_block_size
(
topk_ids
,
num_experts
,
torch
.
ops
.
_moe_C
.
sgl_moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_token_ids
,
block_size
,
sorted_token_ids
,
experts_ids
,
num_tokens_post_pad
)
experts_ids
,
num_tokens_post_pad
)
def ep_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                            block_size: int, sorted_token_ids: torch.Tensor,
                            experts_ids: torch.Tensor,
                            num_tokens_post_pad: torch.Tensor,
                            start_expert: int, end_expert: int) -> None:
    """Run the expert-parallel block-size alignment custom op.

    Thin wrapper that forwards every argument, unchanged and in order, to
    the registered ``ep_moe_align_block_size`` op. The op returns nothing;
    its schema marks ``sorted_token_ids``, ``experts_ids`` and
    ``num_tokens_post_pad`` as mutated tensors (``Tensor!``), so results are
    delivered in place.

    Args:
        topk_ids: Expert id selected for every (token, k) pair.
        num_experts: Number of experts handled by the op.
        block_size: Padding granularity for each expert's token group.
        sorted_token_ids: Output buffer for the sorted token indices.
        experts_ids: Output buffer for the expert index of each block.
        num_tokens_post_pad: Output buffer for the padded token count.
        start_expert: First expert id handled (inclusive).
        end_expert: Last expert id handled (exclusive).
    """
    # NOTE(review): the neighbouring MoE wrappers dispatch through
    # torch.ops._moe_C while this one uses torch.ops._C -- confirm which
    # library the op is actually registered in.
    torch.ops._C.ep_moe_align_block_size(topk_ids, num_experts, block_size,
                                         sorted_token_ids, experts_ids,
                                         num_tokens_post_pad, start_expert,
                                         end_expert)
def
topk_softmax
(
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
def
topk_softmax
(
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
...
...
vllm/config.py
View file @
d0dafaf4
...
@@ -1342,6 +1342,8 @@ class ParallelConfig:
...
@@ -1342,6 +1342,8 @@ class ParallelConfig:
rank
:
int
=
0
rank
:
int
=
0
moe_ep_size
:
Optional
[
int
]
=
1
def
compute_hash
(
self
):
def
compute_hash
(
self
):
"""
"""
Provide a hash that uniquely identifies all the configs
Provide a hash that uniquely identifies all the configs
...
...
vllm/engine/arg_utils.py
View file @
d0dafaf4
...
@@ -206,6 +206,8 @@ class EngineArgs:
...
@@ -206,6 +206,8 @@ class EngineArgs:
calculate_kv_scales
:
Optional
[
bool
]
=
None
calculate_kv_scales
:
Optional
[
bool
]
=
None
moe_ep_size
:
int
=
1
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
not
self
.
tokenizer
:
if
not
self
.
tokenizer
:
self
.
tokenizer
=
self
.
model
self
.
tokenizer
=
self
.
model
...
@@ -417,6 +419,10 @@ class EngineArgs:
...
@@ -417,6 +419,10 @@ class EngineArgs:
type
=
int
,
type
=
int
,
default
=
EngineArgs
.
tensor_parallel_size
,
default
=
EngineArgs
.
tensor_parallel_size
,
help
=
'Number of tensor parallel replicas.'
)
help
=
'Number of tensor parallel replicas.'
)
parser
.
add_argument
(
'--moe-ep-size'
,
type
=
int
,
default
=
EngineArgs
.
moe_ep_size
,
help
=
'Number of moe expert parallel replicas.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--max-parallel-loading-workers'
,
'--max-parallel-loading-workers'
,
type
=
int
,
type
=
int
,
...
@@ -1123,6 +1129,7 @@ class EngineArgs:
...
@@ -1123,6 +1129,7 @@ class EngineArgs:
ray_workers_use_nsight
=
self
.
ray_workers_use_nsight
,
ray_workers_use_nsight
=
self
.
ray_workers_use_nsight
,
distributed_executor_backend
=
self
.
distributed_executor_backend
,
distributed_executor_backend
=
self
.
distributed_executor_backend
,
worker_cls
=
self
.
worker_cls
,
worker_cls
=
self
.
worker_cls
,
moe_ep_size
=
self
.
moe_ep_size
,
)
)
max_model_len
=
model_config
.
max_model_len
max_model_len
=
model_config
.
max_model_len
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
d0dafaf4
...
@@ -633,6 +633,66 @@ def moe_align_block_size(
...
@@ -633,6 +633,66 @@ def moe_align_block_size(
return
sorted_ids
,
expert_ids
,
num_tokens_post_pad
return
sorted_ids
,
expert_ids
,
num_tokens_post_pad
def moe_ep_align_block_size(
        topk_ids: torch.Tensor, block_size: int, num_experts: int,
        start_expert: int,
        end_expert: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Expert-parallel variant of ``moe_align_block_size``.

    Aligns the token distribution across the experts owned by this rank so
    that it is compatible with the block size used for matrix multiplication.
    Only entries of ``topk_ids`` that fall in ``[start_expert, end_expert)``
    are placed; all other entries are skipped.

    Parameters:
    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
        top-k expert indices for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The number of experts handled by this call; it sizes the
        output buffers.
    - start_expert: First expert id handled (inclusive).
    - end_expert: Last expert id handled (exclusive).

    Returns:
    - sorted_token_ids: A tensor containing the flat token indices grouped by
        expert. It is pre-filled with ``topk_ids.numel()``, so every slot the
        op does not write (padding, skipped tokens) holds that value.
    - expert_ids: A tensor with one entry per block. Allocated uninitialized;
        only the entries written by the op are meaningful. The kernel in this
        change stores the local index ``expert_id - start_expert``.
    - num_tokens_post_padded: A one-element tensor written by the op with the
        total number of tokens after padding each expert's group to a
        multiple of block_size.

    Raises:
    - ValueError: if the expert window is wider than ``num_experts``.

    Example (expected result, derived from the kernel in this change):
    Given topk_ids = [[0, 2], [1, 3], [2, 3]], block_size = 2,
    num_experts = 2, start_expert = 2 and end_expert = 4:
    - Flattened, topk_ids is [0, 2, 1, 3, 2, 3]; positions 1 and 4 go to
        expert 2 (local 0), positions 3 and 5 go to expert 3 (local 1), and
        positions 0 and 2 are skipped.
    - sorted_token_ids has 6 + 2 * (2 - 1) = 8 slots and becomes
        [1, 4, 3, 5, 6, 6, 6, 6], where 6 is the fill value.
    - The first two entries of expert_ids are [0, 1] and
        num_tokens_post_padded is 4.
    """
    # The op addresses its per-expert tables with (expert_id - start_expert);
    # a window wider than num_experts would index outside those tables.
    if end_expert - start_expert > num_experts:
        raise ValueError(
            f"expert range [{start_expert}, {end_expert}) is wider than "
            f"num_experts={num_experts}")

    num_tokens = topk_ids.numel()
    # Worst case: every expert needs (block_size - 1) padding slots.
    max_num_tokens_padded = num_tokens + num_experts * (block_size - 1)
    sorted_ids = torch.empty((max_num_tokens_padded, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    # num_tokens is one past the last valid flat index, so it marks padding.
    sorted_ids.fill_(num_tokens)
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty((max_num_m_blocks, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    num_tokens_post_pad = torch.empty((1, ),
                                      dtype=torch.int32,
                                      device=topk_ids.device)
    ops.ep_moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                                expert_ids, num_tokens_post_pad, start_expert,
                                end_expert)
    return sorted_ids, expert_ids, num_tokens_post_pad
def
invoke_fused_moe_kernel
(
A
:
torch
.
Tensor
,
def
invoke_fused_moe_kernel
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
...
@@ -1029,11 +1089,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -1029,11 +1089,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
use_nn_moe
:
Optional
[
bool
]
=
False
,
moe_ep_size
:
Optional
[
int
]
=
1
,
start_expert
:
Optional
[
int
]
=
-
1
,
end_expert
:
Optional
[
int
]
=
-
1
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
w1_scale
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
use_nn_moe
,
moe_ep_size
=
moe_ep_size
,
start_expert
=
start_expert
,
end_expert
=
end_expert
)
def
inplace_fused_experts_fake
(
def
inplace_fused_experts_fake
(
...
@@ -1052,7 +1116,10 @@ def inplace_fused_experts_fake(
...
@@ -1052,7 +1116,10 @@ def inplace_fused_experts_fake(
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
use_nn_moe
:
Optional
[
bool
]
=
False
,
moe_ep_size
:
Optional
[
int
]
=
1
,
start_expert
:
Optional
[
int
]
=
-
1
,
end_expert
:
Optional
[
int
]
=
-
1
)
->
None
:
pass
pass
...
@@ -1080,12 +1147,16 @@ def outplace_fused_experts(
...
@@ -1080,12 +1147,16 @@ def outplace_fused_experts(
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
use_nn_moe
:
Optional
[
bool
]
=
False
,
moe_ep_size
:
Optional
[
int
]
=
1
,
start_expert
:
Optional
[
int
]
=
-
1
,
end_expert
:
Optional
[
int
]
=
-
1
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
use_fp8_w8a8
,
use_int8_w8a16
,
False
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
use_int4_w4a16
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
use_nn_moe
,
moe_ep_size
=
moe_ep_size
,
start_expert
=
start_expert
,
end_expert
=
end_expert
)
def
outplace_fused_experts_fake
(
def
outplace_fused_experts_fake
(
...
@@ -1104,7 +1175,10 @@ def outplace_fused_experts_fake(
...
@@ -1104,7 +1175,10 @@ def outplace_fused_experts_fake(
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
use_nn_moe
:
Optional
[
bool
]
=
False
,
moe_ep_size
:
Optional
[
int
]
=
1
,
start_expert
:
Optional
[
int
]
=
-
1
,
end_expert
:
Optional
[
int
]
=
-
1
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
return
torch
.
empty_like
(
hidden_states
)
...
@@ -1132,7 +1206,10 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1132,7 +1206,10 @@ def fused_experts(hidden_states: torch.Tensor,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
):
use_nn_moe
:
Optional
[
bool
]
=
False
,
moe_ep_size
:
Optional
[
int
]
=
1
,
start_expert
:
Optional
[
int
]
=
-
1
,
end_expert
:
Optional
[
int
]
=
-
1
):
if
inplace
:
if
inplace
:
torch
.
ops
.
vllm
.
inplace_fused_experts
(
hidden_states
,
w1
,
w2
,
torch
.
ops
.
vllm
.
inplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
topk_weights
,
topk_ids
,
...
@@ -1140,14 +1217,19 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1140,14 +1217,19 @@ def fused_experts(hidden_states: torch.Tensor,
use_int4_w4a16
,
w1_scale
,
use_int4_w4a16
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
a2_scale
,
block_shape
,
use_nn_moe
)
use_nn_moe
,
moe_ep_size
=
moe_ep_size
,
start_expert
=
start_expert
,
end_expert
=
end_expert
)
return
hidden_states
return
hidden_states
else
:
else
:
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
use_int8_w8a16
,
use_int4_w4a16
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
use_nn_moe
,
moe_ep_size
=
moe_ep_size
,
start_expert
=
start_expert
,
end_expert
=
end_expert
)
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
...
@@ -1166,7 +1248,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1166,7 +1248,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
):
use_nn_moe
:
Optional
[
bool
]
=
False
,
moe_ep_size
:
Optional
[
int
]
=
1
,
start_expert
:
Optional
[
int
]
=
-
1
,
end_expert
:
Optional
[
int
]
=
-
1
):
# Check constraints.
# Check constraints.
if
use_int4_w4a16
:
if
use_int4_w4a16
:
assert
hidden_states
.
shape
[
1
]
//
2
==
w1
.
shape
[
assert
hidden_states
.
shape
[
1
]
//
2
==
w1
.
shape
[
...
@@ -1219,6 +1304,9 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1219,6 +1304,9 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache3
=
torch
.
empty
((
M
,
topk_ids
.
shape
[
1
],
w2
.
shape
[
1
]
if
not
use_nn_moe
else
w2
.
shape
[
2
]),
intermediate_cache3
=
torch
.
empty
((
M
,
topk_ids
.
shape
[
1
],
w2
.
shape
[
1
]
if
not
use_nn_moe
else
w2
.
shape
[
2
]),
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
dtype
=
hidden_states
.
dtype
)
if
moe_ep_size
>
1
:
intermediate_cache3
.
zero_
()
if
hidden_states
.
dtype
==
torch
.
bfloat16
:
if
hidden_states
.
dtype
==
torch
.
bfloat16
:
compute_type
=
tl
.
bfloat16
compute_type
=
tl
.
bfloat16
...
@@ -1259,6 +1347,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1259,6 +1347,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
))
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
))
if
moe_ep_size
==
1
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
))
else
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_ep_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
,
start_expert
,
end_expert
))
invoke_fused_moe_kernel
(
curr_hidden_states
,
invoke_fused_moe_kernel
(
curr_hidden_states
,
w1
,
w1
,
...
@@ -1333,6 +1429,9 @@ def fused_moe(
...
@@ -1333,6 +1429,9 @@ def fused_moe(
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
moe_ep_size
:
Optional
[
int
]
=
None
,
start_expert
:
Optional
[
int
]
=
None
,
end_expert
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
This function computes a Mixture of Experts (MoE) layer using two sets of
...
@@ -1405,4 +1504,7 @@ def fused_moe(
...
@@ -1405,4 +1504,7 @@ def fused_moe(
a1_scale
=
a1_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
,
moe_ep_size
=
moe_ep_size
,
start_expert
=
start_expert
,
end_expert
=
end_expert
)
vllm/model_executor/layers/fused_moe/layer.py
View file @
d0dafaf4
...
@@ -58,7 +58,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -58,7 +58,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -134,7 +134,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -134,7 +134,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
use_nn_moe
:
Optional
[
bool
]
=
False
,
moe_ep_size
:
Optional
[
int
]
=
1
,
start_expert
:
Optional
[
int
]
=
-
1
,
end_expert
:
Optional
[
int
]
=
-
1
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
forward
(
x
=
x
,
return
self
.
forward
(
x
=
x
,
layer
=
layer
,
layer
=
layer
,
...
@@ -147,7 +150,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -147,7 +150,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function
=
custom_routing_function
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
e_score_correction_bias
=
e_score_correction_bias
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
,
moe_ep_size
=
moe_ep_size
,
start_expert
=
start_expert
,
end_expert
=
end_expert
)
def
forward_cuda
(
def
forward_cuda
(
self
,
self
,
...
@@ -162,7 +168,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -162,7 +168,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
use_nn_moe
:
Optional
[
bool
]
=
False
,
moe_ep_size
:
Optional
[
int
]
=
1
,
start_expert
:
Optional
[
int
]
=
-
1
,
end_expert
:
Optional
[
int
]
=
-
1
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
@@ -182,7 +191,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -182,7 +191,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
True
,
inplace
=
True
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
,
moe_ep_size
=
moe_ep_size
,
start_expert
=
start_expert
,
end_expert
=
end_expert
)
def
forward_cpu
(
def
forward_cpu
(
self
,
self
,
...
@@ -221,7 +233,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -221,7 +233,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
moe_ep_size
:
Optional
[
int
]
=
1
,
start_expert
:
Optional
[
int
]
=
-
1
,
end_expert
:
Optional
[
int
]
=
-
1
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
not
use_grouped_topk
assert
not
use_grouped_topk
assert
num_expert_group
is
None
assert
num_expert_group
is
None
...
@@ -282,6 +297,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -282,6 +297,7 @@ class FusedMoE(torch.nn.Module):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
moe_ep_size
:
Optional
[
int
]
=
1
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -305,6 +321,17 @@ class FusedMoE(torch.nn.Module):
...
@@ -305,6 +321,17 @@ class FusedMoE(torch.nn.Module):
self
.
scoring_func
=
scoring_func
self
.
scoring_func
=
scoring_func
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
moe_ep_size
=
moe_ep_size
self
.
moe_tp_rank
=
self
.
tp_rank
//
self
.
moe_ep_size
self
.
moe_tp_size
=
self
.
tp_size
//
self
.
moe_ep_size
if
self
.
moe_ep_size
>
1
:
self
.
intermediate_size_per_partition
=
intermediate_size
//
self
.
moe_tp_size
self
.
moe_ep_rank
=
self
.
tp_rank
%
self
.
moe_ep_size
num_experts_per_node
=
num_experts
//
self
.
moe_ep_size
self
.
start_expert
=
num_experts_per_node
*
self
.
moe_ep_rank
self
.
end_expert
=
self
.
start_expert
+
num_experts_per_node
if
self
.
scoring_func
!=
"softmax"
and
not
self
.
use_grouped_topk
:
if
self
.
scoring_func
!=
"softmax"
and
not
self
.
use_grouped_topk
:
raise
ValueError
(
"Only softmax scoring function is supported for "
raise
ValueError
(
"Only softmax scoring function is supported for "
"non-grouped topk."
)
"non-grouped topk."
)
...
@@ -323,7 +350,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -323,7 +350,7 @@ class FusedMoE(torch.nn.Module):
self
.
use_nn_moe
=
False
self
.
use_nn_moe
=
False
moe_quant_params
=
{
moe_quant_params
=
{
"num_experts"
:
num_experts
,
"num_experts"
:
num_experts
if
self
.
moe_ep_size
==
1
else
num_experts_per_node
,
"hidden_size"
:
hidden_size
,
"hidden_size"
:
hidden_size
,
"intermediate_size_per_partition"
:
"intermediate_size_per_partition"
:
self
.
intermediate_size_per_partition
,
self
.
intermediate_size_per_partition
,
...
@@ -489,8 +516,10 @@ class FusedMoE(torch.nn.Module):
...
@@ -489,8 +516,10 @@ class FusedMoE(torch.nn.Module):
# dimension intermediate_size_per_partition is used.
# dimension intermediate_size_per_partition is used.
SHARD_ID_TO_SHARDED_DIM
=
{
"w1"
:
0
,
"w2"
:
1
,
"w3"
:
0
}
SHARD_ID_TO_SHARDED_DIM
=
{
"w1"
:
0
,
"w2"
:
1
,
"w3"
:
0
}
expert_id
=
expert_id
-
self
.
start_expert
expert_data
=
param
.
data
[
expert_id
]
expert_data
=
param
.
data
[
expert_id
]
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
tp_rank
//
self
.
moe_ep_size
# is_transposed: if the dim to shard the weight
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be flipped. Required by GPTQ, compressed-tensors
...
@@ -638,7 +667,10 @@ class FusedMoE(torch.nn.Module):
...
@@ -638,7 +667,10 @@ class FusedMoE(torch.nn.Module):
custom_routing_function
=
self
.
custom_routing_function
,
custom_routing_function
=
self
.
custom_routing_function
,
scoring_func
=
self
.
scoring_func
,
scoring_func
=
self
.
scoring_func
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
use_nn_moe
=
self
.
use_nn_moe
)
use_nn_moe
=
self
.
use_nn_moe
,
moe_ep_size
=
self
.
moe_ep_size
,
start_expert
=
self
.
start_expert
,
end_expert
=
self
.
end_expert
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
=
tensor_model_parallel_all_reduce
(
...
@@ -663,6 +695,33 @@ class FusedMoE(torch.nn.Module):
...
@@ -663,6 +695,33 @@ class FusedMoE(torch.nn.Module):
(
"w3"
,
ckpt_up_proj_name
),
(
"w3"
,
ckpt_up_proj_name
),
]
]
]
]
@classmethod
def make_expert_params_mapping_ep(
        cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
        ckpt_up_proj_name: str, num_experts: int,
        moe_ep_size) -> List[Tuple[str, str, int, str]]:
    """Build the checkpoint-to-parameter mapping for this rank's experts.

    The expert-parallel rank is taken as
    ``get_tensor_model_parallel_rank() % moe_ep_size``. Each rank is given a
    contiguous slice of ``num_experts // moe_ep_size`` expert ids, so only
    those experts show up in the returned mapping. Because the slice size
    uses floor division, any remainder experts are not covered by any rank.

    Args:
        ckpt_gate_proj_name: Checkpoint name of the gate projection ("w1").
        ckpt_down_proj_name: Checkpoint name of the down projection ("w2").
        ckpt_up_proj_name: Checkpoint name of the up projection ("w3").
        num_experts: Total number of experts before partitioning.
        moe_ep_size: Number of expert-parallel partitions.

    Returns:
        A list of ``(param_name, weight_name, expert_id, shard_id)`` tuples,
        ordered by expert id and, within an expert, by shard w1, w2, w3.
        ``expert_id`` is the global (un-shifted) expert index.
    """
    ep_rank = get_tensor_model_parallel_rank() % moe_ep_size
    local_expert_count = num_experts // moe_ep_size
    first_expert = ep_rank * local_expert_count

    # (shard_id, checkpoint weight name) pairs emitted for every expert.
    shard_sources = (
        ("w1", ckpt_gate_proj_name),
        ("w2", ckpt_down_proj_name),
        ("w3", ckpt_up_proj_name),
    )
    # Gate and up projections map to the "w13" parameter prefix.
    fused_sources = [ckpt_gate_proj_name, ckpt_up_proj_name]

    mapping = []
    for expert_id in range(first_expert, first_expert + local_expert_count):
        for shard_id, weight_name in shard_sources:
            if weight_name in fused_sources:
                param_prefix = "experts.w13_"
            else:
                param_prefix = "experts.w2_"
            mapping.append((param_prefix,
                            f"experts.{expert_id}.{weight_name}.",
                            expert_id, shard_id))
    return mapping
def
_load_fp8_scale
(
self
,
param
:
torch
.
nn
.
Parameter
,
def
_load_fp8_scale
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
...
...
vllm/model_executor/model_loader/loader.py
View file @
d0dafaf4
...
@@ -149,6 +149,8 @@ def _initialize_model(
...
@@ -149,6 +149,8 @@ def _initialize_model(
kwargs
[
"lora_config"
]
=
vllm_config
.
lora_config
kwargs
[
"lora_config"
]
=
vllm_config
.
lora_config
if
"scheduler_config"
in
all_params
:
if
"scheduler_config"
in
all_params
:
kwargs
[
"scheduler_config"
]
=
vllm_config
.
scheduler_config
kwargs
[
"scheduler_config"
]
=
vllm_config
.
scheduler_config
if
"parallel_config"
in
all_params
:
kwargs
[
"parallel_config"
]
=
vllm_config
.
parallel_config
with
set_current_vllm_config
(
vllm_config
,
check_compile
=
True
):
with
set_current_vllm_config
(
vllm_config
,
check_compile
=
True
):
return
model_class
(
**
kwargs
)
return
model_class
(
**
kwargs
)
...
...
vllm/model_executor/models/deepseek_v3.py
View file @
d0dafaf4
...
@@ -29,7 +29,7 @@ from torch import nn
...
@@ -29,7 +29,7 @@ from torch import nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
,
ParallelConfig
from
vllm.distributed
import
(
get_pp_group
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
...
@@ -99,6 +99,7 @@ class DeepseekV3MoE(nn.Module):
...
@@ -99,6 +99,7 @@ class DeepseekV3MoE(nn.Module):
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
moe_ep_size
:
int
=
1
):
):
super
().
__init__
()
super
().
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
...
@@ -138,7 +139,8 @@ class DeepseekV3MoE(nn.Module):
...
@@ -138,7 +139,8 @@ class DeepseekV3MoE(nn.Module):
topk_group
=
config
.
topk_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
scoring_func
=
config
.
scoring_func
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
)
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
moe_ep_size
=
moe_ep_size
)
if
config
.
n_shared_experts
is
not
None
:
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
intermediate_size
=
(
config
.
moe_intermediate_size
*
...
@@ -488,6 +490,7 @@ class DeepseekV3DecoderLayer(nn.Module):
...
@@ -488,6 +490,7 @@ class DeepseekV3DecoderLayer(nn.Module):
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
moe_ep_size
:
int
=
1
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -526,6 +529,7 @@ class DeepseekV3DecoderLayer(nn.Module):
...
@@ -526,6 +529,7 @@ class DeepseekV3DecoderLayer(nn.Module):
config
=
config
,
config
=
config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
prefix
=
f
"
{
prefix
}
.mlp"
,
moe_ep_size
=
moe_ep_size
)
)
else
:
else
:
self
.
mlp
=
DeepseekV3MLP
(
self
.
mlp
=
DeepseekV3MLP
(
...
@@ -575,7 +579,7 @@ class DeepseekV3Model(nn.Module):
...
@@ -575,7 +579,7 @@ class DeepseekV3Model(nn.Module):
fall_back_to_pt_during_load
=
False
fall_back_to_pt_during_load
=
False
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
moe_ep_size
:
int
=
1
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
...
@@ -602,6 +606,7 @@ class DeepseekV3Model(nn.Module):
...
@@ -602,6 +606,7 @@ class DeepseekV3Model(nn.Module):
model_config
=
model_config
,
model_config
=
model_config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
moe_ep_size
=
moe_ep_size
),
),
prefix
=
f
"
{
prefix
}
.layers"
)
prefix
=
f
"
{
prefix
}
.layers"
)
...
@@ -660,8 +665,12 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
...
@@ -660,8 +665,12 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
moe_ep_size
=
self
.
parallel_config
.
moe_ep_size
self
.
model
=
DeepseekV3Model
(
vllm_config
=
vllm_config
,
self
.
model
=
DeepseekV3Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
),
moe_ep_size
=
self
.
moe_ep_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
...
@@ -737,11 +746,19 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
...
@@ -737,11 +746,19 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
# Params for weights, fp8 weight scales, fp8 activation scales
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
if
self
.
moe_ep_size
==
1
:
ckpt_gate_proj_name
=
"gate_proj"
,
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_down_proj_name
=
"down_proj"
,
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
num_experts
=
self
.
config
.
n_routed_experts
)
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
n_routed_experts
)
else
:
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
n_routed_experts
,
moe_ep_size
=
self
.
moe_ep_size
)
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
loaded_params
:
Set
[
str
]
=
set
()
...
@@ -806,6 +823,10 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
...
@@ -806,6 +823,10 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
# Skip loading extra expert weights for ep moe mode
if
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
...
...
vllm/model_executor/models/mixtral.py
View file @
d0dafaf4
...
@@ -75,7 +75,8 @@ class MixtralMoE(nn.Module):
...
@@ -75,7 +75,8 @@ class MixtralMoE(nn.Module):
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
,
moe_ep_size
:
int
=
1
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -97,7 +98,8 @@ class MixtralMoE(nn.Module):
...
@@ -97,7 +98,8 @@ class MixtralMoE(nn.Module):
renormalize
=
True
,
renormalize
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
tp_size
=
tp_size
,
tp_size
=
tp_size
,
prefix
=
f
"
{
prefix
}
.experts"
)
prefix
=
f
"
{
prefix
}
.experts"
,
moe_ep_size
=
moe_ep_size
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE: hidden_states can have either 1D or 2D shape.
# NOTE: hidden_states can have either 1D or 2D shape.
...
@@ -198,6 +200,7 @@ class MixtralDecoderLayer(nn.Module):
...
@@ -198,6 +200,7 @@ class MixtralDecoderLayer(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
moe_ep_size
:
int
=
1
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -218,7 +221,8 @@ class MixtralDecoderLayer(nn.Module):
...
@@ -218,7 +221,8 @@ class MixtralDecoderLayer(nn.Module):
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.block_sparse_moe"
)
prefix
=
f
"
{
prefix
}
.block_sparse_moe"
,
moe_ep_size
=
moe_ep_size
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
...
@@ -256,7 +260,7 @@ class MixtralDecoderLayer(nn.Module):
...
@@ -256,7 +260,7 @@ class MixtralDecoderLayer(nn.Module):
@
support_torch_compile
@
support_torch_compile
class
MixtralModel
(
nn
.
Module
):
class
MixtralModel
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
moe_ep_size
:
int
=
1
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
...
@@ -279,7 +283,7 @@ class MixtralModel(nn.Module):
...
@@ -279,7 +283,7 @@ class MixtralModel(nn.Module):
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
prefix
:
MixtralDecoderLayer
(
lambda
prefix
:
MixtralDecoderLayer
(
config
,
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
config
,
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
,
moe_ep_size
=
moe_ep_size
),
),
prefix
=
f
"
{
prefix
}
.layers"
)
prefix
=
f
"
{
prefix
}
.layers"
)
...
@@ -355,8 +359,12 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -355,8 +359,12 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
moe_ep_size
=
self
.
parallel_config
.
moe_ep_size
self
.
model
=
MixtralModel
(
vllm_config
=
vllm_config
,
self
.
model
=
MixtralModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
),
moe_ep_size
=
self
.
moe_ep_size
)
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
@@ -430,11 +438,19 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -430,11 +438,19 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# Params for weights, fp8 weight scales, fp8 activation scales
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
if
self
.
moe_ep_size
==
1
:
ckpt_gate_proj_name
=
"w1"
,
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_down_proj_name
=
"w2"
,
ckpt_gate_proj_name
=
"w1"
,
ckpt_up_proj_name
=
"w3"
,
ckpt_down_proj_name
=
"w2"
,
num_experts
=
self
.
config
.
num_local_experts
)
ckpt_up_proj_name
=
"w3"
,
num_experts
=
self
.
config
.
num_local_experts
)
else
:
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping_ep
(
ckpt_gate_proj_name
=
"w1"
,
ckpt_down_proj_name
=
"w2"
,
ckpt_up_proj_name
=
"w3"
,
num_experts
=
self
.
config
.
num_local_experts
,
moe_ep_size
=
self
.
moe_ep_size
)
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
loaded_params
:
Set
[
str
]
=
set
()
...
@@ -486,6 +502,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -486,6 +502,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if
((
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
))
if
((
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
))
and
name
not
in
params_dict
):
and
name
not
in
params_dict
):
continue
continue
# Skip loading extra expert weights for ep moe mode
if
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
weight_loader
(
param
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment