Unverified Commit f49777ba authored by Simon Mo's avatar Simon Mo Committed by GitHub
Browse files

Deepseek v3 (#11502)


Signed-off-by: default avatarmgoin <michael@neuralmagic.com>
Co-authored-by: default avatarmgoin <michael@neuralmagic.com>
Co-authored-by: default avatarrobertgshaw2-neuralmagic <rshaw@neuralmagic.com>
parent 55fb97f7
...@@ -113,6 +113,92 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, ...@@ -113,6 +113,92 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
} }
} }
// TODO(simon): this is temporarily adapted from
// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7
// we did this to unblock Deepseek V3 but there should be a better
// implementation to manage shared memory.
template <typename scalar_t>
__global__ void moe_align_block_size_global_mem_kernel(
scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) {
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
}
/**
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are
* assigned to expert expert_index.
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
}
__syncthreads();
// For each expert we accumulate the token counts from the different threads.
if (threadIdx.x < num_experts) {
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
}
__syncthreads();
// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i - 1] +
CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
block_size) *
block_size;
}
*total_tokens_post_pad = cumsum[num_experts];
}
__syncthreads();
/**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
}
/**
* Each thread processes a token shard, calculating the index of each token
* after sorting by expert number. Given the example topk_ids =
* [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
* *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
* padding value(preset in python).
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
}
}
template <typename scalar_t, int TOPK> template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel( __global__ void moe_sum_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
...@@ -137,25 +223,61 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -137,25 +223,61 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) { torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // If we have very large number of experts, we can no longer use shared
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` // memory.
// tensors // TODO(simon): the right solution should be calculating the exact right
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); // amount of shared memory and use that. The num_experts >= 256 is just a
const int32_t shared_mem = // temporary solution to unblock Deepseek V3.
((num_thread + 1) * num_experts + (num_experts + 1)) * if (num_experts >= 256) {
sizeof(int32_t); VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// set dynamic shared mem // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>; // tensors
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
(void*)kernel, shared_mem));
kernel<<<1, num_thread, shared_mem, stream>>>( const int32_t mem_tokens_cnts =
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(), ((num_experts + 1) * num_experts) * sizeof(int32_t);
experts_ids.data_ptr<int32_t>(), const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, // allocate global memory
topk_ids.numel()); int32_t* tokens_cnts;
}); int32_t* cumsum;
cudaMalloc(&tokens_cnts, mem_tokens_cnts);
cudaMalloc(&cumsum, mem_cumsum);
auto kernel =
vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
kernel<<<1, num_thread, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), tokens_cnts, cumsum);
cudaFree(tokens_cnts);
cudaFree(cumsum);
});
} else {
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
const int32_t shared_mem =
((num_thread + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);
// set dynamic shared mem
auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
kernel<<<1, num_thread, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel());
});
}
} }
void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
......
...@@ -596,6 +596,12 @@ class ModelConfig: ...@@ -596,6 +596,12 @@ class ModelConfig:
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len) self.max_model_len)
if (self.hf_config.model_type == 'deepseek_v3'
and not self.enforce_eager):
logger.warning("CUDA graph is not supported for Deepseek V3 yet, "
"fallback to the eager mode.")
self.enforce_eager = True
def _verify_bnb_config(self) -> None: def _verify_bnb_config(self) -> None:
""" """
The current version of bitsandbytes (0.44.0) with 8-bit models does not The current version of bitsandbytes (0.44.0) with 8-bit models does not
...@@ -712,8 +718,9 @@ class ModelConfig: ...@@ -712,8 +718,9 @@ class ModelConfig:
def get_head_size(self) -> int: def get_head_size(self) -> int:
# TODO remove hard code # TODO remove hard code
if hasattr(self.hf_text_config, "model_type" if hasattr(self.hf_text_config,
) and self.hf_text_config.model_type == 'deepseek_v2': "model_type") and (self.hf_text_config.model_type
in ('deepseek_v2', 'deepseek_v3')):
# FlashAttention supports only head_size 32, 64, 128, 256, # FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256 # we need to pad head_size 192 to 256
return 256 return 256
......
...@@ -476,18 +476,29 @@ def fused_topk( ...@@ -476,18 +476,29 @@ def fused_topk(
return topk_weights, topk_ids return topk_weights, topk_ids
# This is used by the Deepseek-V2 model # This is used by the Deepseek-V2 and Deepseek-V3 model
def grouped_topk(hidden_states: torch.Tensor, def grouped_topk(hidden_states: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
num_expert_group: int = 0, num_expert_group: int = 0,
topk_group: int = 0): topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None):
assert hidden_states.shape[0] == gating_output.shape[0], ( assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch") "Number of tokens mismatch")
scores = torch.softmax(gating_output, dim=-1) if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
if e_score_correction_bias is not None:
scores.add_(e_score_correction_bias.unsqueeze(0))
num_token = scores.shape[0] num_token = scores.shape[0]
group_scores = scores.view(num_token, num_expert_group, group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group] -1).max(dim=-1).values # [n, n_group]
......
...@@ -73,16 +73,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -73,16 +73,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool, renormalize: bool,
use_grouped_topk: bool, use_grouped_topk: bool,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
return self.forward(x=x, return self.forward(x=x,
layer=layer, layer=layer,
...@@ -92,19 +94,23 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -92,19 +94,23 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function) custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
def forward_cuda( def forward_cuda(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool, use_grouped_topk: bool,
top_k: int, top_k: int,
router_logits: torch.Tensor, router_logits: torch.Tensor,
renormalize: bool, renormalize: bool,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -114,7 +120,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -114,7 +120,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function) custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(hidden_states=x, return fused_experts(hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
...@@ -128,21 +136,29 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -128,21 +136,29 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"The CPU backend currently does not support MoE.") "The CPU backend currently does not support MoE.")
def forward_tpu( def forward_tpu(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool, use_grouped_topk: bool,
top_k: int, top_k: int,
router_logits: torch.Tensor, router_logits: torch.Tensor,
renormalize: bool, renormalize: bool,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
assert not use_grouped_topk assert not use_grouped_topk
assert num_expert_group is None assert num_expert_group is None
assert topk_group is None assert topk_group is None
assert custom_routing_function is None assert custom_routing_function is None
if scoring_func != "softmax":
raise NotImplementedError(
"Only softmax scoring function is supported for TPU.")
if e_score_correction_bias is not None:
raise NotImplementedError(
"Expert score correction bias is not supported for TPU.")
return fused_moe_pallas(hidden_states=x, return fused_moe_pallas(hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
...@@ -156,7 +172,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -156,7 +172,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
class FusedMoE(torch.nn.Module): class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models. """FusedMoE layer for MoE models.
This layer contains both MergedColumnParallel weights (gate_up_proj / This layer contains both MergedColumnParallel weights (gate_up_proj /
w13) and RowParallelLinear weights (down_proj/ w2). w13) and RowParallelLinear weights (down_proj/ w2).
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
...@@ -190,6 +206,8 @@ class FusedMoE(torch.nn.Module): ...@@ -190,6 +206,8 @@ class FusedMoE(torch.nn.Module):
tp_size: Optional[int] = None, tp_size: Optional[int] = None,
prefix: str = "", prefix: str = "",
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
): ):
super().__init__() super().__init__()
...@@ -210,6 +228,12 @@ class FusedMoE(torch.nn.Module): ...@@ -210,6 +228,12 @@ class FusedMoE(torch.nn.Module):
self.num_expert_group = num_expert_group self.num_expert_group = num_expert_group
self.topk_group = topk_group self.topk_group = topk_group
self.custom_routing_function = custom_routing_function self.custom_routing_function = custom_routing_function
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
if quant_config is None: if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = ( self.quant_method: Optional[QuantizeMethodBase] = (
...@@ -446,7 +470,9 @@ class FusedMoE(torch.nn.Module): ...@@ -446,7 +470,9 @@ class FusedMoE(torch.nn.Module):
renormalize: bool, renormalize: bool,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None): custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None):
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk) fused_topk, grouped_topk)
...@@ -460,7 +486,9 @@ class FusedMoE(torch.nn.Module): ...@@ -460,7 +486,9 @@ class FusedMoE(torch.nn.Module):
topk=top_k, topk=top_k,
renormalize=renormalize, renormalize=renormalize,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
topk_group=topk_group) topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
elif custom_routing_function is None: elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
...@@ -489,7 +517,9 @@ class FusedMoE(torch.nn.Module): ...@@ -489,7 +517,9 @@ class FusedMoE(torch.nn.Module):
use_grouped_topk=self.use_grouped_topk, use_grouped_topk=self.use_grouped_topk,
topk_group=self.topk_group, topk_group=self.topk_group,
num_expert_group=self.num_expert_group, num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function) custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
......
...@@ -605,6 +605,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -605,6 +605,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -617,7 +619,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -617,7 +619,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function) custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
return fused_experts( return fused_experts(
x, x,
......
This diff is collapsed.
...@@ -45,6 +45,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -45,6 +45,7 @@ _TEXT_GENERATION_MODELS = {
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"DeepseekV3ForCausalLM": ("deepseek_v3", "DeepseekV3ForCausalLM"),
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment