"vscode:/vscode.git/clone" did not exist on "ccaa0bf282683dc1647147271f63ffbd0e972603"
Unverified commit a59636bb authored by Lianmin Zheng, committed by GitHub

Update grok 1 model (#1095)

parent fe502432
......
@@ -88,6 +88,9 @@ def main(args):
for i in range(len(states)):
preds.append(get_answer_value(states[i]["answer"]))
# print(f"{preds=}")
# print(f"{labels=}")
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
......
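For reference, a small self-contained sketch of the accuracy and invalid-rate computation in the hunk above. The INVALID sentinel value below is an assumption, standing in for whatever get_answer_value returns when it cannot parse an answer.

import numpy as np

INVALID = -9999999  # assumed sentinel for unparsable answers
preds = np.array([42, 7, INVALID, 3])
labels = np.array([42, 8, 5, 3])
acc = np.mean(preds == labels)       # 0.5: two of four predictions match the labels
invalid = np.mean(preds == INVALID)  # 0.25: one of four answers failed to parse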
......
@@ -221,6 +221,7 @@ def correctness_test(
# Prepare inputs
input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
rank_print(f"{input_ids=}")
if bench_args.cut_len > 0:
# Prefill
......
......
@@ -14,7 +14,6 @@ limitations under the License.
"""Fused operators for activation layers."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from flashinfer.activation import silu_and_mul
from vllm.model_executor.custom_op import CustomOp
......
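As context for the flashinfer import above, a plain-PyTorch sketch of the semantics a silu_and_mul kernel computes (the usual gated-MLP activation): split the last dimension in half, apply SiLU to the first half, and multiply by the second half. This is a reference for the math only, not the fused implementation.

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x has shape [..., 2 * d]; the result has shape [..., d].
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

y = silu_and_mul_reference(torch.randn(4, 8))  # -> shape [4, 4]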
from sglang.srt.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/fused_moe/fused_moe.py#L1
# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
"""Fused MoE kernel."""
import functools
import json
......
@@ -24,6 +9,7 @@ from typing import Any, Dict, Optional, Tuple
import torch
import triton
import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
......
@@ -373,6 +359,31 @@ def get_default_config(
return config
def try_get_optimal_moe_config(
w1_shape: Tuple[int, ...],
w2_shape: Tuple[int, ...],
top_k: int,
dtype: Optional[str],
M: int,
override_config: Optional[Dict[str, Any]] = None,
):
if override_config:
config = override_config
else:
# First try to load optimal config from the file
E, _, N = w2_shape
configs = get_moe_configs(E, N, dtype)
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype)
return config
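To illustrate the nearest-key lookup in try_get_optimal_moe_config, a toy example with made-up tuned configs keyed by batch size M (real ones are loaded by get_moe_configs):

# Hypothetical tuned configs keyed by M; values are illustrative only.
configs = {1: {"BLOCK_SIZE_M": 16}, 16: {"BLOCK_SIZE_M": 32}, 64: {"BLOCK_SIZE_M": 64}}
M = 24
best_key = min(configs.keys(), key=lambda x: abs(x - M))
print(best_key, configs[best_key])  # 16 {'BLOCK_SIZE_M': 32}: 16 is the key closest to 24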
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
......
@@ -403,6 +414,41 @@ def fused_topk(
return topk_weights, topk_ids
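The body of fused_topk is elided above; as a rough reference for the routing math it implements (softmax over experts, per-token top-k, optional renormalization), here is a plain-PyTorch sketch, not the fused kernel:

import torch

def reference_topk(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # Softmax over the expert dimension, then per-token top-k selection.
    scores = torch.softmax(gating_output.float(), dim=-1)  # [num_tokens, num_experts]
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

weights, ids = reference_topk(torch.randn(4, 8), topk=2, renormalize=True)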
# This is used by the Deepseek-V2 model
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
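A tiny worked example of the group-limited routing above, with made-up numbers: 8 experts in 4 groups of 2, keep the top-2 groups by their per-group maximum, then pick the top-2 experts among the survivors. The scores tensor stands in for the post-softmax gating scores.

import torch

scores = torch.tensor([[0.30, 0.05, 0.20, 0.10, 0.02, 0.03, 0.15, 0.15]])
num_expert_group, topk_group, topk = 4, 2, 2

group_scores = scores.view(1, num_expert_group, -1).max(dim=-1).values       # [1, 4]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]  # groups 0 and 1 survive
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = group_mask.unsqueeze(-1).expand(1, num_expert_group, 2).reshape(1, -1)
masked = scores.masked_fill(~score_mask.bool(), 0.0)
print(torch.topk(masked, k=topk, dim=-1))  # experts 0 and 2, weights 0.30 and 0.20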
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
......
@@ -425,24 +471,23 @@ def fused_experts(
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
M, _ = hidden_states.shape
num_tokens, _ = hidden_states.shape
E, N, _ = w1.shape
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
"float8" if use_fp8 else None,
override_config=override_config,
)
if override_config:
config = override_config
else:
# First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = get_default_config(
M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
)
config = get_config_func(M)
intermediate_cache1 = torch.empty(
(M, topk_ids.shape[1], N),
......
@@ -460,56 +505,85 @@
dtype=hidden_states.dtype,
)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config["BLOCK_SIZE_M"], E
)
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
invoke_fused_moe_kernel(
hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
if inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
begin_chunk_idx, end_chunk_idx = (
chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE, num_tokens),
)
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.shape
if tokens_in_chunk == 0:
break
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
# Adjust the intermediate cache size and config for the last
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
config = get_config_func(tokens_in_chunk)
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids, config["BLOCK_SIZE_M"], E
)
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
invoke_fused_moe_kernel(
curr_hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
if inplace:
return torch.sum(
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states,
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
return out_hidden_states
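A minimal sketch of the chunk-boundary arithmetic used in the loop above; the CHUNK_SIZE value here is illustrative, the real one comes from envs.VLLM_FUSED_MOE_CHUNK_SIZE:

num_tokens, CHUNK_SIZE = 70_000, 32_768
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
    begin = chunk * CHUNK_SIZE
    end = min((chunk + 1) * CHUNK_SIZE, num_tokens)
    if end - begin == 0:
        break  # only hit when num_tokens is an exact multiple of CHUNK_SIZE
    print(chunk, begin, end)  # (0, 0, 32768), (1, 32768, 65536), (2, 65536, 70000)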
def fused_moe(
......
@@ -521,6 +595,9 @@ def fused_moe(
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
......
@@ -543,6 +620,10 @@ def fused_moe(
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
......
@@ -556,12 +637,18 @@ def fused_moe(
# Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
if hasattr(ops, "topk_softmax"):
topk_weights, topk_ids = fused_topk(
hidden_states, gating_output, topk, renormalize
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
topk_weights, topk_ids = grouped_topk(
hidden_states,
gating_output,
topk,
renormalize,
num_expert_group,
topk_group,
)
else:
topk_weights, topk_ids = fused_topk_v0_4_3(
topk_weights, topk_ids = fused_topk(
hidden_states, gating_output, topk, renormalize
)
......
@@ -579,33 +666,3 @@ def fused_moe(
a1_scale=a1_scale,
a2_scale=a2_scale,
)
def fused_topk_v0_4_3(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
import vllm._moe_C as moe_kernels
M, _ = hidden_states.shape
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
token_expert_indicies = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
moe_kernels.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
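For illustration, a hedged sketch of how a DeepSeek-V2-style caller might pass the new grouped-topk arguments to fused_moe. Tensor shapes and the group numbers are assumptions, and the positional parameters before topk follow the signature only partially shown above.

# Hypothetical call; shapes and group sizes are illustrative only.
out = fused_moe(
    hidden_states,          # [num_tokens, hidden_size]
    w1,                     # [num_experts, 2 * intermediate_size, hidden_size]
    w2,                     # [num_experts, hidden_size, intermediate_size]
    gating_output,          # [num_tokens, num_experts]
    topk=6,
    renormalize=True,
    use_grouped_topk=True,  # route with grouped_topk instead of fused_topk
    num_expert_group=8,
    topk_group=3,
)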
This diff is collapsed.
......
@@ -164,9 +164,9 @@ class LogitsProcessor(nn.Module):
last_logits = last_logits[:, : self.config.vocab_size].float()
if hasattr(self.config, "final_logit_softcapping"):
last_logits /= self.config.final_logit_softcapping
last_logits.div_(self.config.final_logit_softcapping)
last_logits = torch.tanh(last_logits)
last_logits *= self.config.final_logit_softcapping
last_logits.mul_(self.config.final_logit_softcapping)
# Return only last_logits if logprob is not requested
if not logits_metadata.return_logprob:
......
@@ -209,9 +209,9 @@ class LogitsProcessor(nn.Module):
all_logits = all_logits[:, : self.config.vocab_size].float()
if hasattr(self.config, "final_logit_softcapping"):
all_logits /= self.config.final_logit_softcapping
all_logits.div_(self.config.final_logit_softcapping)
all_logits = torch.tanh(all_logits)
all_logits *= self.config.final_logit_softcapping
all_logits.mul_(self.config.final_logit_softcapping)
all_logprobs = all_logits
del all_logits, hidden_states
......
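The two hunks above switch the division and multiplication of the final-logit softcapping to their in-place variants. Mathematically the transform is cap * tanh(logits / cap), which bounds the logits to the interval (-cap, cap); a minimal out-of-place sketch of the same math:

import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Squash logits into (-cap, cap) with a scaled tanh.
    return cap * torch.tanh(logits / cap)

x = torch.tensor([-100.0, 0.0, 5.0, 100.0])
print(softcap(x, cap=30.0))  # values squashed into (-30, 30)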
......
@@ -53,7 +53,7 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_available_gpu_memory,
is_generation_model,
is_llama3_405b_fp8,
is_llama3_405b_fp8_head_16,
is_multimodal_model,
monkey_patch_vllm_dummy_weight_loader,
monkey_patch_vllm_p2p_access_check,
......
@@ -158,7 +158,7 @@ class ModelRunner:
skip_tokenizer_init=True,
)
if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8:
if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
self.model_config.hf_config.num_key_value_heads = 8
vllm_model_config.hf_config.num_key_value_heads = 8
......
This diff is collapsed.
......
@@ -32,7 +32,6 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
......
......
@@ -35,7 +35,6 @@ import torch
import torch.distributed as dist
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from starlette.middleware.base import BaseHTTPMiddleware
from torch.nn.parameter import Parameter
from triton.runtime.cache import (
FileCacheManager,
......
@@ -644,7 +643,7 @@ def set_ulimit(target_soft_limit=65535):
logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
def is_llama3_405b_fp8(model_config):
def is_llama3_405b_fp8_head_16(model_config):
"""Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
if (
model_config.hf_config.architectures[0] == "LlamaForCausalLM"
......