"git@developer.sourcefind.cn:OpenDAS/fairseq.git" did not exist on "4ac2c5f2cc8a8b1f221f1e8e9b7839f07c25d997"
Unverified commit f0d4e145, authored by Woosuk Kwon, committed by GitHub

Add fused top-K softmax kernel for MoE (#2769)

parent 2ccee3de
#include "moe_ops.h"
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
}
csrc/moe/moe_ops.h (new file):

#pragma once

#include <torch/extension.h>

void topk_softmax(
  torch::Tensor& topk_weights,
  torch::Tensor& topk_indices,
  torch::Tensor& token_expert_indices,
  torch::Tensor& gating_output);
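The topk_softmax op declared above fuses the softmax over the router (gating) logits with the per-token top-k selection. As a rough functional reference that mirrors the ROCm fallback added to fused_moe.py later in this diff, its first two outputs correspond to the unfused PyTorch sketch below (names and shapes here are illustrative, not part of the commit). The third output, token_expert_indices, has no counterpart in the sketch; the Python caller in this commit discards it.

import torch

def topk_softmax_reference(gating_output: torch.Tensor, topk: int):
    """Unfused reference; gating_output has shape (num_tokens, num_experts)."""
    # Softmax over the expert dimension, computed in float32.
    probs = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    # Per-token top-k expert weights and expert indices.
    topk_weights, topk_indices = torch.topk(probs, topk, dim=-1)
    return topk_weights, topk_indices

# Example: route 4 tokens over 8 experts, selecting 2 experts per token.
weights, indices = topk_softmax_reference(torch.randn(4, 8), topk=2)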
(The next diff is collapsed and not shown here; given the build rule below that compiles csrc/moe/*.cu, it is presumably the CUDA implementation of the top-k softmax kernel.)
csrc/pybind.cpp:

@@ -48,8 +48,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     &rotary_embedding,
     "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
 
-#ifndef USE_ROCM
   // Quantization ops
+#ifndef USE_ROCM
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
   ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
 #endif
setup.py:

@@ -339,6 +339,17 @@ if _is_cuda():
     vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
     vllm_extension_sources.append("csrc/custom_all_reduce.cu")
 
+    # Add MoE kernels.
+    ext_modules.append(
+        CUDAExtension(
+            name="vllm._moe_C",
+            sources=glob("csrc/moe/*.cu") + glob("csrc/moe/*.cpp"),
+            extra_compile_args={
+                "cxx": CXX_FLAGS,
+                "nvcc": NVCC_FLAGS,
+            },
+        ))
+
 if not _is_neuron():
     vllm_extension = CUDAExtension(
         name="vllm._C",
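The kernel is compiled into its own extension module, vllm._moe_C, separate from vllm._C. A quick, hypothetical sanity check (not part of the commit) that the op is available after a from-source CUDA build:

import torch  # the extension links against libtorch, so import torch first

import vllm._moe_C as moe_kernels

# Should expose the binding registered in csrc/moe/moe_ops.cpp.
print(hasattr(moe_kernels, "topk_softmax"))  # expected: True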
tests/kernels/test_moe.py:

@@ -2,10 +2,8 @@
 Run `pytest tests/kernels/test_moe.py`.
 """
-
 import pytest
 import torch
-
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

@@ -14,22 +12,21 @@ from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.models.mixtral import MixtralMoE
 
 
-def torch_moe(a, w1, w2, topk_weight, topk_ids):
+def torch_moe(a, w1, w2, score, topk):
     B, D = a.shape
-    a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
-    out = torch.zeros(B * topk_ids.shape[1],
-                      w2.shape[1],
-                      dtype=a.dtype,
-                      device=a.device)
-    topk_ids = topk_ids.view(-1)
+    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+    score = torch.softmax(score, dim=-1, dtype=torch.float32)
+    topk_weight, topk_ids = torch.topk(score, topk)
     topk_weight = topk_weight.view(-1)
+    topk_ids = topk_ids.view(-1)
     for i in range(w1.shape[0]):
         mask = topk_ids == i
         if mask.sum():
             out[mask] = SiluAndMul()(
                 a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
     return (out.view(B, -1, w2.shape[1]) *
-            topk_weight.view(B, -1, 1)).sum(dim=1)
+            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
 
 
 @pytest.mark.parametrize("m", [512, 222, 33, 1])

@@ -51,11 +48,8 @@ def test_fused_moe(
     w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
 
     score = torch.randn((m, e), device='cuda', dtype=dtype)
-    score = torch.softmax(score, dim=-1)
-    topk_weight, topk_ids = torch.topk(score, topk)
-
-    triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
-    torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
+    triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
+    torch_output = torch_moe(a, w1, w2, score, topk)
     assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)

@@ -75,7 +69,7 @@ def test_mixtral_moe(dtype: torch.dtype):
         intermediate_size=config.intermediate_size,
         params_dtype=dtype,
         tp_size=1,
-    )
+    ).cuda()
 
     # Load the weights
     vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
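One detail worth noting: test_fused_moe now passes renormalize=False because torch_moe takes the softmax over all experts and uses the selected top-k probabilities as-is. A small illustration of the difference (arbitrary values, not from the commit):

import torch

gating = torch.randn(4, 8)  # 4 tokens, 8 experts
probs = torch.softmax(gating, dim=-1, dtype=torch.float32)
weights, ids = torch.topk(probs, k=2, dim=-1)

print(weights.sum(dim=-1))  # sums are < 1.0: top-k weights are not renormalized
renormed = weights / weights.sum(dim=-1, keepdim=True)
print(renormed.sum(dim=-1))  # sums are 1.0: what renormalize=True produces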
vllm/model_executor/layers/fused_moe.py:

@@ -4,6 +4,7 @@ import triton
 import triton.language as tl
 
 from vllm._C import ops
+from vllm.utils import is_hip
 
 
 @triton.jit

@@ -177,7 +178,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
                             mul_routed_weight: bool, top_k: int, config: dict):
-
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1

@@ -210,12 +210,15 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
     )
 
 
-def fused_moe(hidden_states: torch.Tensor,
-              w1: torch.Tensor,
-              w2: torch.Tensor,
-              topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor,
-              inplace=False):
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.

@@ -223,15 +226,19 @@ def fused_moe(hidden_states: torch.Tensor,
     - hidden_states (torch.Tensor): The input tensor to the MoE layer.
     - w1 (torch.Tensor): The first set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
-    - topk_weights (torch.Tensor): The weights for the top-k selected experts.
-    - topk_ids (torch.Tensor): The indices of the top-k selected experts.
+    - gating_output (torch.Tensor): The output of the gating operation (before softmax).
+    - topk (int): The number of top-k experts to select.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - inplace (bool): If True, perform the operation in-place. Defaults to False.
 
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions"
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
+    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"

@@ -241,6 +248,37 @@ def fused_moe(hidden_states: torch.Tensor,
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
 
+    if is_hip():
+        # The MoE kernels are not yet supported on ROCm.
+        routing_weights = torch.softmax(gating_output,
+                                        dim=-1,
+                                        dtype=torch.float32)
+        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
+    else:
+        import vllm._moe_C as moe_kernels
+
+        topk_weights = torch.empty(M,
+                                   topk,
+                                   dtype=torch.float32,
+                                   device=hidden_states.device)
+        topk_ids = torch.empty(M,
+                               topk,
+                               dtype=torch.int32,
+                               device=hidden_states.device)
+        token_expert_indicies = torch.empty(M,
+                                            topk,
+                                            dtype=torch.int32,
+                                            device=hidden_states.device)
+        moe_kernels.topk_softmax(
+            topk_weights,
+            topk_ids,
+            token_expert_indicies,
+            gating_output.float(),  # TODO(woosuk): Optimize this.
+        )
+        del token_expert_indicies  # Not used. Will be used in the future.
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
     config = {
         'BLOCK_SIZE_M': 64,
         'BLOCK_SIZE_N': 64,
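With this change, fused_moe takes the raw router logits plus topk and renormalize instead of precomputed top-k weights and indices. A minimal usage sketch under assumed shapes (a CUDA device and a from-source build are assumed; the (E, 2*N, K) layout of w1 fuses the gate and up projections, matching the kernel test above):

import torch
from vllm.model_executor.layers.fused_moe import fused_moe

m, e, topk = 16, 8, 2   # tokens, experts, experts per token (assumed values)
k, n = 1024, 2048       # hidden size and per-expert intermediate size (assumed values)

a = torch.randn(m, k, device="cuda", dtype=torch.float16) / 10
w1 = torch.randn(e, 2 * n, k, device="cuda", dtype=torch.float16) / 10  # gate + up projections
w2 = torch.randn(e, k, n, device="cuda", dtype=torch.float16) / 10      # down projection
gating_output = torch.randn(m, e, device="cuda", dtype=torch.float16)   # raw router logits

out = fused_moe(a, w1, w2, gating_output, topk, renormalize=True)       # shape (m, k)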
vllm/model_executor/models/deepseek.py:

@@ -25,7 +25,6 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
-import torch.nn.functional as F
 from transformers import PretrainedConfig
 
 from vllm.model_executor.input_metadata import InputMetadata

@@ -155,20 +154,12 @@ class DeepseekMoE(nn.Module):
         shared_output = self.shared_experts(hidden_states)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights,
-                                                       self.top_k,
-                                                       dim=-1)
-        if self.config.norm_topk_prob:
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
         final_hidden_states = fused_moe(hidden_states,
                                         self.w1,
                                         self.w2,
-                                        routing_weights,
-                                        selected_experts,
+                                        router_logits,
+                                        self.top_k,
+                                        renormalize=self.config.norm_topk_prob,
                                         inplace=True)
 
         if self.config.n_shared_experts is not None:
vllm/model_executor/models/mixtral.py:

@@ -24,8 +24,6 @@
 from typing import List, Optional, Tuple
 
 import torch
-import torch.nn.functional as F
-
 from torch import nn
 from transformers import MixtralConfig

@@ -128,18 +126,12 @@ class MixtralMoE(nn.Module):
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights,
-                                                       self.top_k,
-                                                       dim=-1)
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
         final_hidden_states = fused_moe(hidden_states,
                                         self.ws,
                                         self.w2s,
-                                        routing_weights,
-                                        selected_experts,
+                                        router_logits,
+                                        self.top_k,
+                                        renormalize=True,
                                         inplace=True)
 
         if self.tp_size > 1:
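The Mixtral and Deepseek changes are mechanical: the per-layer softmax, top-k and renormalization move inside fused_moe. A sketch of why the routing is preserved (illustrative only; it mirrors the removed lines above and the ROCm fallback in fused_moe.py):

import torch

router_logits, top_k = torch.randn(4, 8), 2

# Old inline routing, as removed from MixtralMoE.forward above.
old_w = torch.softmax(router_logits, dim=1, dtype=torch.float)
old_w, old_ids = torch.topk(old_w, top_k, dim=-1)
old_w = old_w / old_w.sum(dim=-1, keepdim=True)

# New routing, as now performed inside fused_moe with renormalize=True.
new_w = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
new_w, new_ids = torch.topk(new_w, top_k, dim=-1)
new_w = new_w / new_w.sum(dim=-1, keepdim=True)

assert torch.equal(old_ids, new_ids) and torch.allclose(old_w, new_w)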