Unverified Commit 05339a7b authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[Bugfix][CPU] Fix llama4 inference on CPU (#34321)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
parent 40b8f553
......@@ -238,3 +238,6 @@ ep_kernels_workspace/
vllm/grpc/vllm_engine_pb2.py
vllm/grpc/vllm_engine_pb2_grpc.py
vllm/grpc/vllm_engine_pb2.pyi
# Ignore generated cpu headers
csrc/cpu/cpu_attn_dispatch_generated.h
......@@ -147,7 +147,7 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
const int32_t token_num, const int32_t expert_num,
const int32_t topk_num, const int32_t input_size_13,
const int32_t output_size_13, const int32_t input_size_2,
const int32_t output_size_2) {
const int32_t output_size_2, const bool skip_weighted) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
constexpr int32_t gemm_n_tile_size = gemm_t::NSize;
constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize;
......@@ -582,6 +582,11 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
scalar_t* __restrict__ curr_output_buffer =
output + token_id * output_size_2;
if (skip_weighted) {
// Only for topk_num == 1
*curr_weight = 1.0f;
}
if (topk_num > 1) {
{
int32_t w2_output_idx = curr_expand_token_id_index_buffer[0];
......@@ -699,7 +704,7 @@ void cpu_fused_moe(
const std::optional<torch::Tensor>& w2_bias, // [expert_num, output_size_2]
const torch::Tensor& topk_weights, // [token_num, k], float32
const torch::Tensor& topk_id, // [token_num, k], int32
const std::string& act, const std::string& isa) {
const bool skip_weighted, const std::string& act, const std::string& isa) {
const int32_t token_num = input.size(0);
const int32_t input_size_13 = input.size(1);
const int64_t input_stride = input.stride(0);
......@@ -711,6 +716,8 @@ void cpu_fused_moe(
const int32_t topk_num = topk_id.size(1);
const FusedMOEAct act_type = get_act_type(act);
cpu_utils::ISA isa_type = cpu_utils::get_isa(isa);
TORCH_CHECK(!skip_weighted || topk_num == 1,
"skip_weighted is only supported for topk=1 on CPU");
VLLM_DISPATCH_FLOATING_TYPES(w13.scalar_type(), "cpu_fused_moe", [&]() {
CPU_ISA_DISPATCH_IMPL(isa_type, [&]() {
......@@ -721,7 +728,7 @@ void cpu_fused_moe(
w2_bias.has_value() ? w2_bias->data_ptr<scalar_t>() : nullptr,
topk_weights.data_ptr<float>(), topk_id.data_ptr<int32_t>(), act_type,
token_num, expert_num, topk_num, input_size_13, output_size_13,
input_size_2, output_size_2);
input_size_2, output_size_2, skip_weighted);
});
});
}
......@@ -119,8 +119,8 @@ void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
const std::optional<torch::Tensor>& w13_bias,
const std::optional<torch::Tensor>& w2_bias,
const torch::Tensor& topk_weights,
const torch::Tensor& topk_id, const std::string& act,
const std::string& isa);
const torch::Tensor& topk_id, const bool skip_weighted,
const std::string& act, const std::string& isa);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
......@@ -320,6 +320,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
"Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
"bool skip_weighted, "
"str act, str isa) -> ()");
ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
#endif
......
......@@ -3078,6 +3078,7 @@ def cpu_fused_moe(
topk_ids: torch.Tensor,
act: str,
isa: str,
skip_weighted: bool = False,
) -> torch.Tensor:
output = torch.empty_like(input)
torch.ops._C.cpu_fused_moe(
......@@ -3089,6 +3090,7 @@ def cpu_fused_moe(
w2_bias,
topk_weights,
topk_ids,
skip_weighted,
act,
isa,
)
......
......@@ -238,7 +238,6 @@ class CPUFusedMOE:
activation: str = "silu",
) -> torch.Tensor:
assert activation in _CPU_MOE_ACT_FN, f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts(
hidden_states=x,
......@@ -261,6 +260,7 @@ class CPUFusedMOE:
topk_ids,
activation,
global_num_experts,
apply_router_weight_on_input,
)
def check_grouped_gemm(
......@@ -355,7 +355,14 @@ class CPUFusedMOE:
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int = -1,
skip_weighted: bool = False,
) -> torch.Tensor:
if skip_weighted:
assert topk_ids.size(1) == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
input.mul_(topk_weights.to(input.dtype))
output = cpu_fused_moe(
input,
layer.w13_weight,
......@@ -366,6 +373,7 @@ class CPUFusedMOE:
topk_ids,
activation,
self.isa,
skip_weighted,
)
return output
......@@ -377,7 +385,14 @@ class CPUFusedMOE:
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int = -1,
skip_weighted: bool = False,
) -> torch.Tensor:
if skip_weighted:
assert topk_ids.size(1) == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
input.mul_(topk_weights.to(input.dtype))
output = torch.empty_like(input)
layer_id = id(layer)
torch.ops.vllm.cpu_fused_moe_torch(
......@@ -388,6 +403,7 @@ class CPUFusedMOE:
topk_ids,
activation,
global_num_experts,
skip_weighted,
)
return output
......@@ -401,6 +417,7 @@ def cpu_fused_moe_torch(
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int = -1,
skip_weighted: bool = False,
) -> None:
layer = _CPU_MOE_LAYER_CACHE[layer_id]()
......@@ -434,13 +451,16 @@ def cpu_fused_moe_torch(
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weights.dtype)
.mul_(topk_weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
if skip_weighted:
final_out = new_x
else:
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weights.dtype)
.mul_(topk_weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
output.copy_(final_out)
......
......@@ -160,12 +160,21 @@ class CPUWorker(Worker):
x for x in logical_cpu_list if x.numa_node == selected_numa_node
]
else:
assert len(logical_cpu_list) >= self.parallel_config.world_size
logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.numa_node)
sim_cpu_num_per_node = (
len(logical_cpu_list) // self.parallel_config.world_size
# This is a bit tricky because the internal DP size
# is always 1 for non-MoE models
world_size_across_dp = (
self.parallel_config.world_size
* self.parallel_config._api_process_count
)
start_idx = self.local_rank * sim_cpu_num_per_node
assert len(logical_cpu_list) >= world_size_across_dp
logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.numa_node)
sim_cpu_num_per_node = len(logical_cpu_list) // world_size_across_dp
assert self.parallel_config.data_parallel_rank_local is not None
start_idx = (
self.local_rank
+ self.parallel_config.world_size
* self.parallel_config.data_parallel_rank_local
) * sim_cpu_num_per_node
logical_cpu_list = logical_cpu_list[
start_idx : (start_idx + sim_cpu_num_per_node)
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment