Unverified Commit 05339a7b authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[Bugfix][CPU] Fix llama4 inference on CPU (#34321)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
parent 40b8f553
...@@ -238,3 +238,6 @@ ep_kernels_workspace/ ...@@ -238,3 +238,6 @@ ep_kernels_workspace/
vllm/grpc/vllm_engine_pb2.py vllm/grpc/vllm_engine_pb2.py
vllm/grpc/vllm_engine_pb2_grpc.py vllm/grpc/vllm_engine_pb2_grpc.py
vllm/grpc/vllm_engine_pb2.pyi vllm/grpc/vllm_engine_pb2.pyi
# Ignore generated cpu headers
csrc/cpu/cpu_attn_dispatch_generated.h
...@@ -147,7 +147,7 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input, ...@@ -147,7 +147,7 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
const int32_t token_num, const int32_t expert_num, const int32_t token_num, const int32_t expert_num,
const int32_t topk_num, const int32_t input_size_13, const int32_t topk_num, const int32_t input_size_13,
const int32_t output_size_13, const int32_t input_size_2, const int32_t output_size_13, const int32_t input_size_2,
const int32_t output_size_2) { const int32_t output_size_2, const bool skip_weighted) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t; using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
constexpr int32_t gemm_n_tile_size = gemm_t::NSize; constexpr int32_t gemm_n_tile_size = gemm_t::NSize;
constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize; constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize;
...@@ -582,6 +582,11 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input, ...@@ -582,6 +582,11 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
scalar_t* __restrict__ curr_output_buffer = scalar_t* __restrict__ curr_output_buffer =
output + token_id * output_size_2; output + token_id * output_size_2;
if (skip_weighted) {
// Only for topk_num == 1
*curr_weight = 1.0f;
}
if (topk_num > 1) { if (topk_num > 1) {
{ {
int32_t w2_output_idx = curr_expand_token_id_index_buffer[0]; int32_t w2_output_idx = curr_expand_token_id_index_buffer[0];
...@@ -699,7 +704,7 @@ void cpu_fused_moe( ...@@ -699,7 +704,7 @@ void cpu_fused_moe(
const std::optional<torch::Tensor>& w2_bias, // [expert_num, output_size_2] const std::optional<torch::Tensor>& w2_bias, // [expert_num, output_size_2]
const torch::Tensor& topk_weights, // [token_num, k], float32 const torch::Tensor& topk_weights, // [token_num, k], float32
const torch::Tensor& topk_id, // [token_num, k], int32 const torch::Tensor& topk_id, // [token_num, k], int32
const std::string& act, const std::string& isa) { const bool skip_weighted, const std::string& act, const std::string& isa) {
const int32_t token_num = input.size(0); const int32_t token_num = input.size(0);
const int32_t input_size_13 = input.size(1); const int32_t input_size_13 = input.size(1);
const int64_t input_stride = input.stride(0); const int64_t input_stride = input.stride(0);
...@@ -711,6 +716,8 @@ void cpu_fused_moe( ...@@ -711,6 +716,8 @@ void cpu_fused_moe(
const int32_t topk_num = topk_id.size(1); const int32_t topk_num = topk_id.size(1);
const FusedMOEAct act_type = get_act_type(act); const FusedMOEAct act_type = get_act_type(act);
cpu_utils::ISA isa_type = cpu_utils::get_isa(isa); cpu_utils::ISA isa_type = cpu_utils::get_isa(isa);
TORCH_CHECK(!skip_weighted || topk_num == 1,
"skip_weighted is only supported for topk=1 on CPU");
VLLM_DISPATCH_FLOATING_TYPES(w13.scalar_type(), "cpu_fused_moe", [&]() { VLLM_DISPATCH_FLOATING_TYPES(w13.scalar_type(), "cpu_fused_moe", [&]() {
CPU_ISA_DISPATCH_IMPL(isa_type, [&]() { CPU_ISA_DISPATCH_IMPL(isa_type, [&]() {
...@@ -721,7 +728,7 @@ void cpu_fused_moe( ...@@ -721,7 +728,7 @@ void cpu_fused_moe(
w2_bias.has_value() ? w2_bias->data_ptr<scalar_t>() : nullptr, w2_bias.has_value() ? w2_bias->data_ptr<scalar_t>() : nullptr,
topk_weights.data_ptr<float>(), topk_id.data_ptr<int32_t>(), act_type, topk_weights.data_ptr<float>(), topk_id.data_ptr<int32_t>(), act_type,
token_num, expert_num, topk_num, input_size_13, output_size_13, token_num, expert_num, topk_num, input_size_13, output_size_13,
input_size_2, output_size_2); input_size_2, output_size_2, skip_weighted);
}); });
}); });
} }
...@@ -119,8 +119,8 @@ void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input, ...@@ -119,8 +119,8 @@ void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
const std::optional<torch::Tensor>& w13_bias, const std::optional<torch::Tensor>& w13_bias,
const std::optional<torch::Tensor>& w2_bias, const std::optional<torch::Tensor>& w2_bias,
const torch::Tensor& topk_weights, const torch::Tensor& topk_weights,
const torch::Tensor& topk_id, const std::string& act, const torch::Tensor& topk_id, const bool skip_weighted,
const std::string& isa); const std::string& act, const std::string& isa);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops // vLLM custom ops
...@@ -320,6 +320,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -320,6 +320,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def( ops.def(
"cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, " "cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
"Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, " "Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
"bool skip_weighted, "
"str act, str isa) -> ()"); "str act, str isa) -> ()");
ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe); ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
#endif #endif
......
...@@ -3078,6 +3078,7 @@ def cpu_fused_moe( ...@@ -3078,6 +3078,7 @@ def cpu_fused_moe(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
act: str, act: str,
isa: str, isa: str,
skip_weighted: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
output = torch.empty_like(input) output = torch.empty_like(input)
torch.ops._C.cpu_fused_moe( torch.ops._C.cpu_fused_moe(
...@@ -3089,6 +3090,7 @@ def cpu_fused_moe( ...@@ -3089,6 +3090,7 @@ def cpu_fused_moe(
w2_bias, w2_bias,
topk_weights, topk_weights,
topk_ids, topk_ids,
skip_weighted,
act, act,
isa, isa,
) )
......
...@@ -238,7 +238,6 @@ class CPUFusedMOE: ...@@ -238,7 +238,6 @@ class CPUFusedMOE:
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert activation in _CPU_MOE_ACT_FN, f"{activation} is not supported." assert activation in _CPU_MOE_ACT_FN, f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts( topk_weights, topk_ids = select_experts(
hidden_states=x, hidden_states=x,
...@@ -261,6 +260,7 @@ class CPUFusedMOE: ...@@ -261,6 +260,7 @@ class CPUFusedMOE:
topk_ids, topk_ids,
activation, activation,
global_num_experts, global_num_experts,
apply_router_weight_on_input,
) )
def check_grouped_gemm( def check_grouped_gemm(
...@@ -355,7 +355,14 @@ class CPUFusedMOE: ...@@ -355,7 +355,14 @@ class CPUFusedMOE:
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: str,
global_num_experts: int = -1, global_num_experts: int = -1,
skip_weighted: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if skip_weighted:
assert topk_ids.size(1) == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
input.mul_(topk_weights.to(input.dtype))
output = cpu_fused_moe( output = cpu_fused_moe(
input, input,
layer.w13_weight, layer.w13_weight,
...@@ -366,6 +373,7 @@ class CPUFusedMOE: ...@@ -366,6 +373,7 @@ class CPUFusedMOE:
topk_ids, topk_ids,
activation, activation,
self.isa, self.isa,
skip_weighted,
) )
return output return output
...@@ -377,7 +385,14 @@ class CPUFusedMOE: ...@@ -377,7 +385,14 @@ class CPUFusedMOE:
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: str,
global_num_experts: int = -1, global_num_experts: int = -1,
skip_weighted: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if skip_weighted:
assert topk_ids.size(1) == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
input.mul_(topk_weights.to(input.dtype))
output = torch.empty_like(input) output = torch.empty_like(input)
layer_id = id(layer) layer_id = id(layer)
torch.ops.vllm.cpu_fused_moe_torch( torch.ops.vllm.cpu_fused_moe_torch(
...@@ -388,6 +403,7 @@ class CPUFusedMOE: ...@@ -388,6 +403,7 @@ class CPUFusedMOE:
topk_ids, topk_ids,
activation, activation,
global_num_experts, global_num_experts,
skip_weighted,
) )
return output return output
...@@ -401,6 +417,7 @@ def cpu_fused_moe_torch( ...@@ -401,6 +417,7 @@ def cpu_fused_moe_torch(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: str,
global_num_experts: int = -1, global_num_experts: int = -1,
skip_weighted: bool = False,
) -> None: ) -> None:
layer = _CPU_MOE_LAYER_CACHE[layer_id]() layer = _CPU_MOE_LAYER_CACHE[layer_id]()
...@@ -434,13 +451,16 @@ def cpu_fused_moe_torch( ...@@ -434,13 +451,16 @@ def cpu_fused_moe_torch(
new_x = torch.empty_like(outs) new_x = torch.empty_like(outs)
new_x[idxs] = outs new_x[idxs] = outs
final_out = ( if skip_weighted:
new_x.view(*topk_ids.shape, -1) final_out = new_x
.type(topk_weights.dtype) else:
.mul_(topk_weights.unsqueeze(dim=-1)) final_out = (
.sum(dim=1) new_x.view(*topk_ids.shape, -1)
.type(new_x.dtype) .type(topk_weights.dtype)
) .mul_(topk_weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
output.copy_(final_out) output.copy_(final_out)
......
...@@ -160,12 +160,21 @@ class CPUWorker(Worker): ...@@ -160,12 +160,21 @@ class CPUWorker(Worker):
x for x in logical_cpu_list if x.numa_node == selected_numa_node x for x in logical_cpu_list if x.numa_node == selected_numa_node
] ]
else: else:
assert len(logical_cpu_list) >= self.parallel_config.world_size # This is a bit tricky because the internal DP size
logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.numa_node) # is always 1 for non-MoE models
sim_cpu_num_per_node = ( world_size_across_dp = (
len(logical_cpu_list) // self.parallel_config.world_size self.parallel_config.world_size
* self.parallel_config._api_process_count
) )
start_idx = self.local_rank * sim_cpu_num_per_node assert len(logical_cpu_list) >= world_size_across_dp
logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.numa_node)
sim_cpu_num_per_node = len(logical_cpu_list) // world_size_across_dp
assert self.parallel_config.data_parallel_rank_local is not None
start_idx = (
self.local_rank
+ self.parallel_config.world_size
* self.parallel_config.data_parallel_rank_local
) * sim_cpu_num_per_node
logical_cpu_list = logical_cpu_list[ logical_cpu_list = logical_cpu_list[
start_idx : (start_idx + sim_cpu_num_per_node) start_idx : (start_idx + sim_cpu_num_per_node)
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment