Unverified commit a59636bb authored by Lianmin Zheng, committed by GitHub

Update grok 1 model (#1095)

parent fe502432
@@ -88,6 +88,9 @@ def main(args):
    for i in range(len(states)):
        preds.append(get_answer_value(states[i]["answer"]))

+    # print(f"{preds=}")
+    # print(f"{labels=}")

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)
...
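For context on the two metrics computed above: accuracy is the fraction of parsed answers that match the labels, and the invalid rate is the fraction of answers that could not be parsed at all. A toy illustration, not part of the diff; the INVALID value below is only a placeholder sentinel:

import numpy as np

INVALID = -9999999  # assumed sentinel for unparsable answers; illustrative only

labels = np.array([12, 7, 30, 4])
preds = np.array([12, 7, INVALID, 5])   # one unparsable answer, one wrong answer

acc = np.mean(preds == labels)          # 0.5: two of four answers match the labels
invalid = np.mean(preds == INVALID)     # 0.25: one of four answers could not be parsed
print(f"{acc=} {invalid=}")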
@@ -221,6 +221,7 @@ def correctness_test(
    # Prepare inputs
    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
+   rank_print(f"{input_ids=}")

    if bench_args.cut_len > 0:
        # Prefill
...
@@ -14,7 +14,6 @@ limitations under the License.
"""Fused operators for activation layers."""

import torch
-import torch.nn as nn
import torch.nn.functional as F
from flashinfer.activation import silu_and_mul
from vllm.model_executor.custom_op import CustomOp
...
from sglang.srt.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Adapted from # Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/fused_moe/fused_moe.py#L1 # https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
"""Fused MoE kernel.""" """Fused MoE kernel."""
import functools import functools
import json import json
...@@ -24,6 +9,7 @@ from typing import Any, Dict, Optional, Tuple ...@@ -24,6 +9,7 @@ from typing import Any, Dict, Optional, Tuple
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
@@ -373,6 +359,31 @@ def get_default_config(
    return config

+def try_get_optimal_moe_config(
+    w1_shape: Tuple[int, ...],
+    w2_shape: Tuple[int, ...],
+    top_k: int,
+    dtype: Optional[str],
+    M: int,
+    override_config: Optional[Dict[str, Any]] = None,
+):
+    if override_config:
+        config = override_config
+    else:
+        # First try to load optimal config from the file
+        E, _, N = w2_shape
+        configs = get_moe_configs(E, N, dtype)
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = get_default_config(M, E, N, w1_shape[2], top_k, dtype)
+    return config
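The new try_get_optimal_moe_config helper first honors override_config, then falls back to a tuned-config table keyed by batch size, picking the tuned key closest to M, and finally to the default config. A minimal sketch of the nearest-key lookup, with a made-up table (real tables come from get_moe_configs()):

# Illustrative only; keys and block sizes below are invented.
configs = {
    1: {"BLOCK_SIZE_M": 16},
    16: {"BLOCK_SIZE_M": 32},
    64: {"BLOCK_SIZE_M": 64},
    256: {"BLOCK_SIZE_M": 128},
}

M = 48  # current (possibly chunked) number of tokens
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
print(config)  # picks key 64, the tuned batch size closest to M=48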
def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
@@ -403,6 +414,41 @@ def fused_topk(
    return topk_weights, topk_ids
+# This is used by the Deepseek-V2 model
+def grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    scores = torch.softmax(gating_output, dim=-1)
+    num_token = scores.shape[0]
+    group_scores = (
+        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_ids
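grouped_topk (used by the Deepseek-V2 model, per the comment above) first scores each expert group by its best expert, keeps only the top topk_group groups per token, masks out every expert outside those groups, and then runs an ordinary top-k. A small standalone illustration with made-up sizes (2 tokens, 8 experts in 4 groups of 2), not part of the diff:

import torch

num_expert_group, topk_group, topk = 4, 2, 2
gating_output = torch.randn(2, 8)
scores = torch.softmax(gating_output, dim=-1)

group_scores = scores.view(2, num_expert_group, -1).max(dim=-1).values     # [2, 4]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)      # [2, 4]
score_mask = (
    group_mask.unsqueeze(-1)
    .expand(2, num_expert_group, 8 // num_expert_group)
    .reshape(2, -1)
)                                                                           # [2, 8]
masked = scores.masked_fill(~score_mask.bool(), 0.0)
topk_weights, topk_ids = torch.topk(masked, k=topk, dim=-1)
print(topk_ids)  # chosen expert indices always fall inside the two selected groups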
def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
@@ -425,25 +471,24 @@ def fused_experts(
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

-    M, _ = hidden_states.shape
+    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape

-    if override_config:
-        config = override_config
-    else:
-        # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
-        if configs:
-            # If an optimal configuration map has been found, look up the
-            # optimal config
-            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-        else:
-            # Else use the default config
-            config = get_default_config(
-                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
-            )
+    # We execute the fused_moe kernel in chunks to circumvent this issue:
+    # https://github.com/vllm-project/vllm/issues/5938
+    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
+    M = min(num_tokens, CHUNK_SIZE)
+
+    get_config_func = functools.partial(
+        try_get_optimal_moe_config,
+        w1.shape,
+        w2.shape,
+        topk_ids.shape[1],
+        "float8" if use_fp8 else None,
+        override_config=override_config,
+    )
+    config = get_config_func(M)

    intermediate_cache1 = torch.empty(
        (M, topk_ids.shape[1], N),
        device=hidden_states.device,
@@ -460,19 +505,49 @@ def fused_experts(
        dtype=hidden_states.dtype,
    )
+    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
+
+    if inplace:
+        out_hidden_states = hidden_states
+    else:
+        out_hidden_states = torch.empty_like(hidden_states)
+
+    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
+        begin_chunk_idx, end_chunk_idx = (
+            chunk * CHUNK_SIZE,
+            min((chunk + 1) * CHUNK_SIZE, num_tokens),
+        )
+        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
+        tokens_in_chunk, _ = curr_hidden_states.shape
+
+        if tokens_in_chunk == 0:
+            break
+
+        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
+            # Adjust the intermediate cache size and config for the last
+            # chunk. Note that in most cases we only have one chunk
+            # so the cache size and config are already set correctly and
+            # do not need to be adjusted.
+            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
+            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
+            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
+            config = get_config_func(tokens_in_chunk)
+
+        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
+        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-            topk_ids, config["BLOCK_SIZE_M"], E
+            curr_topk_ids, config["BLOCK_SIZE_M"], E
        )
-    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16

        invoke_fused_moe_kernel(
-            hidden_states,
+            curr_hidden_states,
            w1,
            intermediate_cache1,
            a1_scale,
            w1_scale,
-            topk_weights,
+            curr_topk_weights,
-            topk_ids,
+            curr_topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
@@ -491,8 +566,8 @@ def fused_experts(
            intermediate_cache3,
            a2_scale,
            w2_scale,
-            topk_weights,
+            curr_topk_weights,
-            topk_ids,
+            curr_topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
@@ -503,13 +578,12 @@ def fused_experts(
            use_fp8=use_fp8,
        )

-    if inplace:
-        return torch.sum(
+        torch.sum(
            intermediate_cache3.view(*intermediate_cache3.shape),
            dim=1,
-            out=hidden_states,
+            out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
+    return out_hidden_states
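The rewritten fused_experts caps the kernel's effective batch at VLLM_FUSED_MOE_CHUNK_SIZE and loops over token chunks, shrinking the intermediate caches and re-picking the config for a short final chunk. A minimal, self-contained sketch of just the chunking pattern (process_chunk stands in for the per-chunk kernel work and is not part of the diff):

import torch

def chunked_apply(x: torch.Tensor, chunk_size: int, process_chunk) -> torch.Tensor:
    num_tokens = x.shape[0]
    out = torch.empty_like(x)
    for chunk in range(num_tokens // chunk_size + 1):
        begin, end = chunk * chunk_size, min((chunk + 1) * chunk_size, num_tokens)
        if begin >= end:  # nothing left when num_tokens is a multiple of chunk_size
            break
        out[begin:end] = process_chunk(x[begin:end])
    return out

x = torch.randn(5, 4)
y = chunked_apply(x, chunk_size=2, process_chunk=lambda t: t * 2)
assert torch.equal(y, x * 2)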
def fused_moe(
@@ -521,6 +595,9 @@ def fused_moe(
    renormalize: bool,
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,
+    use_grouped_topk: bool = False,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
    use_fp8: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
@@ -543,6 +620,10 @@ def fused_moe(
        Defaults to False.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
+    - num_expert_group: Optional[int]: additional parameter for grouped_topk
+    - topk_group: Optional[int]: additional parameter for grouped_topk
+    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
+        note: Deepseekv2 model uses grouped_topk
    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -556,12 +637,18 @@ def fused_moe(
    # Check constraints.
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

-    if hasattr(ops, "topk_softmax"):
-        topk_weights, topk_ids = fused_topk(
-            hidden_states, gating_output, topk, renormalize
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            hidden_states,
+            gating_output,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
        )
    else:
-        topk_weights, topk_ids = fused_topk_v0_4_3(
+        topk_weights, topk_ids = fused_topk(
            hidden_states, gating_output, topk, renormalize
        )
@@ -579,33 +666,3 @@ def fused_moe(
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
-def fused_topk_v0_4_3(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-):
-    import vllm._moe_C as moe_kernels
-
-    M, _ = hidden_states.shape
-    topk_weights = torch.empty(
-        M, topk, dtype=torch.float32, device=hidden_states.device
-    )
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(
-        M, topk, dtype=torch.int32, device=hidden_states.device
-    )
-    moe_kernels.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
-    )
-    del token_expert_indicies  # Not used. Will be used in the future.
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-    return topk_weights, topk_ids
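The removed fused_topk_v0_4_3 fallback wrapped the same routing math behind vllm._moe_C.topk_softmax. As a rough pure-PyTorch reference, illustrative only and equivalent up to kernel details: softmax over the gating logits, top-k, then optional renormalization:

import torch

def reference_topk(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # Plain-PyTorch sketch of what the fused topk_softmax path computes.
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

weights, ids = reference_topk(torch.randn(4, 8), topk=2, renormalize=True)
assert torch.allclose(weights.sum(dim=-1), torch.ones(4))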
@@ -164,9 +164,9 @@ class LogitsProcessor(nn.Module):
        last_logits = last_logits[:, : self.config.vocab_size].float()

        if hasattr(self.config, "final_logit_softcapping"):
-            last_logits /= self.config.final_logit_softcapping
+            last_logits.div_(self.config.final_logit_softcapping)
            last_logits = torch.tanh(last_logits)
-            last_logits *= self.config.final_logit_softcapping
+            last_logits.mul_(self.config.final_logit_softcapping)

        # Return only last_logits if logprob is not requested
        if not logits_metadata.return_logprob:
@@ -209,9 +209,9 @@ class LogitsProcessor(nn.Module):
        all_logits = all_logits[:, : self.config.vocab_size].float()

        if hasattr(self.config, "final_logit_softcapping"):
-            all_logits /= self.config.final_logit_softcapping
+            all_logits.div_(self.config.final_logit_softcapping)
            all_logits = torch.tanh(all_logits)
-            all_logits *= self.config.final_logit_softcapping
+            all_logits.mul_(self.config.final_logit_softcapping)

        all_logprobs = all_logits
        del all_logits, hidden_states
...
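The LogitsProcessor change only switches final-logit softcapping to in-place division and multiplication; the math is unchanged: logits are squashed into (-cap, cap) via cap * tanh(logits / cap). A small check (the cap value below is illustrative, e.g. a Gemma-2-style final_logit_softcapping):

import torch

cap = 30.0
logits = torch.tensor([[5.0, 80.0, -200.0]])

# Out-of-place reference: values are squashed into (-cap, cap).
ref = cap * torch.tanh(logits / cap)

# In-place variant matching the diff (avoids extra temporaries).
capped = logits.clone()
capped.div_(cap)
capped = torch.tanh(capped)
capped.mul_(cap)

assert torch.allclose(ref, capped)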
@@ -53,7 +53,7 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    get_available_gpu_memory,
    is_generation_model,
-    is_llama3_405b_fp8,
+    is_llama3_405b_fp8_head_16,
    is_multimodal_model,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
@@ -158,7 +158,7 @@ class ModelRunner:
            skip_tokenizer_init=True,
        )

-        if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8:
+        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
            # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
            self.model_config.hf_config.num_key_value_heads = 8
            vllm_model_config.hf_config.num_key_value_heads = 8
...
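The rename clarifies that the utility only matches Meta-Llama-3.1-405B-FP8 checkpoints that report 16 KV heads; when it matches and tp_size <= 8, ModelRunner overrides num_key_value_heads to 8 on both configs, as shown in the hunk above. A hedged sketch of that override in isolation (ToyConfig is a stand-in and not part of the codebase):

from dataclasses import dataclass

@dataclass
class ToyConfig:
    num_key_value_heads: int

def maybe_fix_llama3_405b_fp8_heads(hf_config: ToyConfig, vllm_hf_config: ToyConfig,
                                    is_405b_fp8_head_16: bool, tp_size: int) -> None:
    # Mirrors the temporary hack in ModelRunner: force 8 KV heads for the
    # Meta-Llama-3.1-405B-FP8 checkpoints that advertise 16.
    if is_405b_fp8_head_16 and tp_size <= 8:
        hf_config.num_key_value_heads = 8
        vllm_hf_config.num_key_value_heads = 8

cfg_a, cfg_b = ToyConfig(16), ToyConfig(16)
maybe_fix_llama3_405b_fp8_heads(cfg_a, cfg_b, is_405b_fp8_head_16=True, tp_size=8)
assert cfg_a.num_key_value_heads == cfg_b.num_key_value_heads == 8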
@@ -32,7 +32,6 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
-    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
...
@@ -35,7 +35,6 @@ import torch
import torch.distributed as dist
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
-from starlette.middleware.base import BaseHTTPMiddleware
from torch.nn.parameter import Parameter
from triton.runtime.cache import (
    FileCacheManager,
@@ -644,7 +643,7 @@ def set_ulimit(target_soft_limit=65535):
        logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")

-def is_llama3_405b_fp8(model_config):
+def is_llama3_405b_fp8_head_16(model_config):
    """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
    if (
        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
...