Unverified commit a59636bb, authored by Lianmin Zheng and committed by GitHub
Browse files

Update grok 1 model (#1095)

parent fe502432
......@@ -88,6 +88,9 @@ def main(args):
for i in range(len(states)):
preds.append(get_answer_value(states[i]["answer"]))
# print(f"{preds=}")
# print(f"{labels=}")
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
......
......@@ -221,6 +221,7 @@ def correctness_test(
# Prepare inputs
input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
rank_print(f"{input_ids=}")
if bench_args.cut_len > 0:
# Prefill
......
......@@ -14,7 +14,6 @@ limitations under the License.
"""Fused operators for activation layers."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from flashinfer.activation import silu_and_mul
from vllm.model_executor.custom_op import CustomOp
......
from sglang.srt.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/fused_moe/fused_moe.py#L1
# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
"""Fused MoE kernel."""
import functools
import json
......@@ -24,6 +9,7 @@ from typing import Any, Dict, Optional, Tuple
import torch
import triton
import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
......@@ -373,6 +359,31 @@ def get_default_config(
return config
def try_get_optimal_moe_config(
    w1_shape: Tuple[int, ...],
    w2_shape: Tuple[int, ...],
    top_k: int,
    dtype: Optional[str],
    M: int,
    override_config: Optional[Dict[str, Any]] = None,
):
    """Pick the kernel config to use for a fused-MoE launch.

    Resolution order:
      1. ``override_config`` when it is truthy (returned as-is).
      2. The entry of the map returned by ``get_moe_configs`` whose key is
         numerically closest to ``M``.
      3. ``get_default_config`` when no such map is available.

    Args:
        w1_shape: Shape of the first expert weight; only index 2 is read.
        w2_shape: Shape of the second expert weight; must unpack into three
            values, of which the first and third are used for the lookup.
        top_k: Forwarded to ``get_default_config``.
        dtype: Forwarded to ``get_moe_configs`` / ``get_default_config``.
        M: Value the config keys are matched against.
        override_config: Optional config that bypasses every lookup.

    Returns:
        The selected config object.
    """
    # An explicit (non-empty) override wins over any lookup.
    if override_config:
        return override_config

    # First try to load optimal config from the file
    num_experts, _, last_dim = w2_shape
    tuned = get_moe_configs(num_experts, last_dim, dtype)
    if not tuned:
        # Nothing tuned is available: use the default config
        return get_default_config(
            M, num_experts, last_dim, w1_shape[2], top_k, dtype
        )

    # A config map was found: take the entry whose key is nearest to M
    nearest_key = min(tuned.keys(), key=lambda candidate: abs(candidate - M))
    return tuned[nearest_key]
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
......@@ -403,6 +414,41 @@ def fused_topk(
return topk_weights, topk_ids
# This is used by the Deepseek-V2 model
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
):
    """Select top-k experts per token, restricted to the best expert groups.

    The softmax of ``gating_output`` is split into ``num_expert_group``
    equal groups along the last dimension. Each group is scored by its
    maximum probability, the ``topk_group`` best groups are kept, every
    expert outside those groups is zeroed, and the ``topk`` largest
    remaining probabilities are returned.

    Args:
        hidden_states: Only its first dimension is read, to check that it
            matches the first dimension of ``gating_output``.
        gating_output: Router scores; softmax is taken over the last dim.
        topk: Number of experts to return per token.
        renormalize: If True, divide the returned weights by their row sum.
        num_expert_group: Number of groups the experts are viewed as.
        topk_group: Number of groups kept per token.

    Returns:
        ``(topk_weights, topk_ids)`` from ``torch.topk`` (unsorted).
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    probs = torch.softmax(gating_output, dim=-1)
    n_tokens = probs.shape[0]

    # Score each group by its best expert.  [n, n_group]
    per_group = probs.view(n_tokens, num_expert_group, -1)
    best_in_group = per_group.max(dim=-1).values

    # Indices of the groups that survive.  [n, top_k_group]
    _, kept_groups = torch.topk(best_in_group, k=topk_group, dim=-1, sorted=False)

    # 0/1 indicator over groups, then broadcast to every expert.  [n, e]
    keep_group = torch.zeros_like(best_in_group)
    keep_group.scatter_(1, kept_groups, 1)
    experts_per_group = probs.shape[-1] // num_expert_group
    keep_expert = keep_group.unsqueeze(-1)
    keep_expert = keep_expert.expand(n_tokens, num_expert_group, experts_per_group)
    keep_expert = keep_expert.reshape(n_tokens, -1)

    # Zero out experts in discarded groups before the final top-k.  [n, e]
    masked_probs = probs.masked_fill(~keep_expert.bool(), 0.0)
    topk_weights, topk_ids = torch.topk(masked_probs, k=topk, dim=-1, sorted=False)

    if renormalize:
        row_total = topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights / row_total

    return topk_weights, topk_ids
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
......@@ -425,24 +471,23 @@ def fused_experts(
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
M, _ = hidden_states.shape
num_tokens, _ = hidden_states.shape
E, N, _ = w1.shape
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
"float8" if use_fp8 else None,
override_config=override_config,
)
if override_config:
config = override_config
else:
# First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = get_default_config(
M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
)
config = get_config_func(M)
intermediate_cache1 = torch.empty(
(M, topk_ids.shape[1], N),
......@@ -460,56 +505,85 @@ def fused_experts(
dtype=hidden_states.dtype,
)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config["BLOCK_SIZE_M"], E
)
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
invoke_fused_moe_kernel(
hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
if inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
begin_chunk_idx, end_chunk_idx = (
chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE, num_tokens),
)
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.shape
if tokens_in_chunk == 0:
break
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
# Adjust the intermediate cache size and config for the last
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
config = get_config_func(tokens_in_chunk)
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids, config["BLOCK_SIZE_M"], E
)
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
invoke_fused_moe_kernel(
curr_hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
if inplace:
return torch.sum(
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states,
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
return out_hidden_states
def fused_moe(
......@@ -521,6 +595,9 @@ def fused_moe(
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
......@@ -543,6 +620,10 @@ def fused_moe(
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
......@@ -556,12 +637,18 @@ def fused_moe(
# Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
if hasattr(ops, "topk_softmax"):
topk_weights, topk_ids = fused_topk(
hidden_states, gating_output, topk, renormalize
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
topk_weights, topk_ids = grouped_topk(
hidden_states,
gating_output,
topk,
renormalize,
num_expert_group,
topk_group,
)
else:
topk_weights, topk_ids = fused_topk_v0_4_3(
topk_weights, topk_ids = fused_topk(
hidden_states, gating_output, topk, renormalize
)
......@@ -579,33 +666,3 @@ def fused_moe(
a1_scale=a1_scale,
a2_scale=a2_scale,
)
def fused_topk_v0_4_3(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """Top-k routing via the ``topk_softmax`` kernel from ``vllm._moe_C``.

    Args:
        hidden_states: 2-D tensor; supplies the row count and the device of
            the output buffers.
        gating_output: Router scores, passed to the kernel as float32.
        topk: Number of experts selected per row.
        renormalize: If True, divide the weights by their per-row sum.

    Returns:
        ``(topk_weights, topk_ids)`` with dtypes float32 and int32.
    """
    import vllm._moe_C as moe_kernels

    num_rows, _ = hidden_states.shape
    target_device = hidden_states.device

    # Output buffers filled in by the kernel.
    topk_weights = torch.empty(
        num_rows, topk, dtype=torch.float32, device=target_device
    )
    topk_ids = torch.empty(num_rows, topk, dtype=torch.int32, device=target_device)
    # Third buffer the kernel requires; its contents are discarded below.
    scratch_indices = torch.empty(
        num_rows, topk, dtype=torch.int32, device=target_device
    )

    gating_fp32 = gating_output.float()  # TODO(woosuk): Optimize this.
    moe_kernels.topk_softmax(topk_weights, topk_ids, scratch_indices, gating_fp32)
    del scratch_indices  # Not used. Will be used in the future.

    if renormalize:
        row_total = topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights / row_total
    return topk_weights, topk_ids
# Adapted from
# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
from abc import abstractmethod
from typing import List, Optional, Tuple
import torch
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class FusedMoEMethodBase(QuantizeMethodBase):
    """Abstract interface for the methods that back a fused MoE layer.

    Subclasses must provide ``create_weights`` (parameter allocation on the
    layer) and ``apply`` (the MoE computation itself).
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the expert parameters on ``layer``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
    ) -> torch.Tensor:
        """Run the MoE computation for ``x`` using the weights on ``layer``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate ``w13_weight`` and ``w2_weight`` and register them on ``layer``.

        Both parameters are uninitialized, have ``requires_grad=False`` and
        receive ``extra_weight_attrs`` through ``set_weight_attrs``.
        """
        weight_specs = (
            # Fused gate_up_proj (column parallel)
            ("w13_weight", (num_experts, 2 * intermediate_size, hidden_size)),
            # down_proj (row parallel)
            ("w2_weight", (num_experts, hidden_size, intermediate_size)),
        )
        for attr_name, shape in weight_specs:
            weight = torch.nn.Parameter(
                torch.empty(*shape, dtype=params_dtype),
                requires_grad=False,
            )
            layer.register_parameter(attr_name, weight)
            set_weight_attrs(weight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
    ) -> torch.Tensor:
        """Call ``self.forward`` with the layer's weights and routing options."""
        forward_args = (
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            top_k,
            renormalize,
            use_grouped_topk,
            num_expert_group,
            topk_group,
        )
        return self.forward(*forward_args)

    def forward_cuda(
        self,
        x: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        num_expert_group: Optional[int],
        topk_group: Optional[int],
    ) -> torch.Tensor:
        """Run sglang's ``fused_moe`` with ``inplace=True``."""
        from sglang.srt.layers.fused_moe.fused_moe import fused_moe

        routing_options = {
            "renormalize": renormalize,
            "inplace": True,
            "use_grouped_topk": use_grouped_topk,
            "num_expert_group": num_expert_group,
            "topk_group": topk_group,
        }
        return fused_moe(x, w1, w2, router_logits, top_k, **routing_options)

    def forward_cpu(self, *args, **kwargs):
        """Reject CPU execution.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("The CPU backend currently does not support MoE.")

    def forward_tpu(
        self,
        x: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        num_expert_group: Optional[int],
        topk_group: Optional[int],
    ) -> torch.Tensor:
        """Run vLLM's Pallas ``fused_moe``; grouped top-k options must be unset."""
        from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe

        # This path takes no grouped top-k arguments, so refuse them outright.
        assert not use_grouped_topk
        assert num_expert_group is None
        assert topk_group is None
        return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.
This layer contains both MergedColumnParallel weights (gate_up_proj /
w13) and RowParallelLinear weights (down_proj / w2).
Note: Mixtral uses w1, w2, and w3 for gate, down, and up_proj
respectively. We copy that naming convention here and handle any
remapping in the load_weights function in each model implementation.
Args:
num_experts: Number of experts in the model
top_k: Number of experts selected for each token
hidden_size: Input hidden state size of the transformer
intermediate_size: Intermediate size of the experts (before it is
divided by the tensor-parallel size)
params_dtype: Data type for the parameters. Defaults to
torch.get_default_dtype() when None.
reduce_results: Whether to all_reduce the output of the layer
renormalize: Whether to renormalize the logits in the fused_moe kernel
use_grouped_topk: Forwarded to the quant method's apply(); requires
num_expert_group and topk_group to be set.
num_expert_group: Forwarded to the quant method's apply().
topk_group: Forwarded to the quant method's apply().
quant_config: Quantization configuration.
tp_size: Tensor-parallel size. Defaults to
get_tensor_model_parallel_world_size() when None.
prefix: Forwarded to quant_config.get_quant_method().
"""
# Build the layer: resolve dtype and tensor-parallel size, choose the
# quantization method, and let that method allocate the expert weights.
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
):
super().__init__()
# Fall back to torch's current default dtype when none is given.
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.tp_size = (
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
)
self.top_k = top_k
self.num_experts = num_experts
# Per-rank slice of the intermediate dimension (integer division).
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
# Grouped top-k routing needs both group parameters.
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
# NOTE(review): indentation is lost in this view. forward() reads both
# attributes unconditionally, so these assignments are presumably outside
# the `if` above - confirm against the real file.
self.num_expert_group = num_expert_group
self.topk_group = topk_group
# Without a quantization config, use the unquantized method.
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
UnquantizedFusedMoEMethod()
)
else:
# Fp8Config is handled by Fp8MoEMethod; any other config supplies its
# own method through get_quant_method().
if isinstance(quant_config, Fp8Config):
self.quant_method = Fp8MoEMethod(quant_config)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None
# weight_loader is passed as an extra weight attribute so the created
# parameters carry a reference to this layer's loader.
self.quant_method.create_weights(
layer=self,
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=self.intermediate_size_per_partition,
params_dtype=params_dtype,
weight_loader=self.weight_loader,
)
# Copy one checkpoint tensor into the slot of `param` that belongs to
# `expert_id`. `weight_name` selects the handling (input scale, weight scale,
# or plain weight). `shard_id` is 0 for w1/gate_proj, 1 for w2/down_proj and
# 2 for w3/up_proj. `pre_sharded` means `loaded_weight` is taken whole instead
# of being sliced by tensor-parallel rank.
# Raises ValueError for mismatching input scales or an unknown shard_id.
def weight_loader(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: int,
expert_id: int,
pre_sharded: bool,
):
param_data = param.data
# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
# A stored value of 1 is treated as "not yet loaded"; otherwise the new
# value must match the stored one within 1e-5.
if (
param_data[expert_id] != 1
and (param_data[expert_id] - loaded_weight).abs() > 1e-5
):
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}"
)
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale" in weight_name:
# If we are in merged column case (gate_up_proj)
# shard_id 0 == gate_proj / w1
# shard_id 2 == up_proj / w3
if shard_id == 0 or shard_id == 2:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == 0 else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
# shard_id 1 == down_proj / w2
else:
param_data[expert_id] = loaded_weight
# Weights
else:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.intermediate_size_per_partition
# Pre-sharded tensors are used whole; otherwise take this rank's slice.
if pre_sharded:
shard = slice(None)
else:
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
# w1, gate_proj case: Load into first shard of w13.
if shard_id == 0:
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
# w3, up_proj case: Load into second shard of w13.
elif shard_id == 2:
param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
shard, :
]
# w2, down_proj case: Load into only shard of w2.
elif shard_id == 1:
param_data[expert_id, :, :] = loaded_weight[:, shard]
else:
raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}")
# Run the MoE computation through the quant method, then all-reduce the
# result when reduce_results is set and tp_size > 1.
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert self.quant_method is not None
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
)
if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
# Build the list of (param_name, weight_name, expert_id, shard_id) tuples for
# every expert and each of the gate/down/up projections. The list holds the
# weight-scale entries first, then the weights, then the input scales.
# shard_id is the position in [gate, down, up], i.e. 0, 1, 2.
@classmethod
def make_expert_params_mapping(
cls,
ckpt_gate_proj_name: str,
ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int,
) -> List[Tuple[str, str, int, int]]:
# gate and up map onto the fused w13 parameters; down maps onto w2.
gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
gate_down_up = [ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name]
return (
[
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
(
(
"experts.w13_scale"
if weight_name in gate_up
else "experts.w2_scale"
),
f"experts.{expert_id}.{weight_name}.weight_scale",
expert_id,
shard_id,
)
for expert_id in range(num_experts)
for shard_id, weight_name in enumerate(gate_down_up)
]
+ [
# These are the weights for the experts
# (param_name, weight_name, expert_id, shard_id)
(
(
"experts.w13_weight"
if weight_name in gate_up
else "experts.w2_weight"
),
f"experts.{expert_id}.{weight_name}.weight",
expert_id,
shard_id,
)
for expert_id in range(num_experts)
for shard_id, weight_name in enumerate(gate_down_up)
]
+ [
# These are the input (activation) scales for the experts
# (param_name, weight_name, expert_id, shard_id)
(
(
"experts.a13_scale"
if weight_name in gate_up
else "experts.a2_scale"
),
f"experts.{expert_id}.{weight_name}.input_scale",
expert_id,
shard_id,
)
for expert_id in range(num_experts)
for shard_id, weight_name in enumerate(gate_down_up)
]
)
import torch
from torch.nn import Module
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d,
per_tensor_dequantize,
)
from vllm.utils import print_warning_once
class Fp8MoEMethod(FusedMoEMethodBase):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
# Only stores the config; all parameters are created in create_weights().
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
# Register the expert weights, weight scales and (for the static activation
# scheme) input scales on `layer`.
# Raises ValueError when the activation scheme is "static" but the checkpoint
# is not fp8-serialized.
def create_weights(
self,
layer: Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# For fp8-serialized checkpoints the weights are allocated as float8_e4m3fn.
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size, dtype=params_dtype
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_scale", w13_scale)
w2_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w2_scale", w2_scale)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_scale, extra_weight_attrs)
set_weight_attrs(w2_scale, extra_weight_attrs)
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
a13_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("a13_scale", a13_scale)
set_weight_attrs(a13_scale, extra_weight_attrs)
a2_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("a2_scale", a2_scale)
set_weight_attrs(a2_scale, extra_weight_attrs)
else:
# Any other activation scheme: no input scales are stored on the layer.
layer.a13_scale = None
layer.a2_scale = None
# Post-load fix-up. For non-fp8 checkpoints, quantize the weights to fp8 per
# expert. For fp8 checkpoints, collapse the scales to what the MoE kernels
# accept (one activation scale, one w13 weight scale per expert).
# Raises ValueError when the static scheme is configured but an activation
# scale is None.
def process_weights_after_loading(self, layer: Module) -> None:
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
w13_weight = torch.empty_like(
layer.w13_weight.data, dtype=torch.float8_e4m3fn
)
w2_weight = torch.empty_like(
layer.w2_weight.data, dtype=torch.float8_e4m3fn
)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
layer.w13_scale = torch.nn.Parameter(
torch.ones(
layer.num_experts, dtype=torch.float32, device=w13_weight.device
),
requires_grad=False,
)
# scaled_fp8_quant returns (quantized tensor, scale); both are stored
# per expert.
for expert in range(layer.num_experts):
w13_weight[expert, :, :], layer.w13_scale[expert] = (
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant(
layer.w2_weight.data[expert, :, :]
)
# Swap the original weights for their fp8 versions.
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
return
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.quant_config.activation_scheme == "static":
if layer.a13_scale is None or layer.a2_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if not all_close_1d(layer.a13_scale) or not all_close_1d(
layer.a2_scale
):
print_warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. "
)
layer.a13_scale = torch.nn.Parameter(
layer.a13_scale.max(), requires_grad=False
)
layer.a2_scale = torch.nn.Parameter(
layer.a2_scale.max(), requires_grad=False
)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_scale is not None
shard_size = layer.intermediate_size_per_partition
# Per expert: the larger of the two scales stored for w1 and w3.
max_w13_scales = layer.w13_scale.max(dim=1).values
for expert_id in range(layer.num_experts):
start = 0
# Shard 0 and shard 1 are consecutive row ranges of length shard_size
# within the expert's w13 weight.
for shard_id in range(2):
# Dequantize with the shard's own scale ...
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_scale[expert_id][shard_id],
)
# ... and requantize in place with the shared per-expert max scale.
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
)
start += shard_size
layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
return
# Run sglang's fused_moe with use_fp8=True and inplace=True, passing the
# weights and scales stored on `layer`. The kernel is imported at call time.
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
) -> torch.Tensor:
from sglang.srt.layers.fused_moe.fused_moe import fused_moe
return fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
top_k,
renormalize=renormalize,
inplace=True,
use_fp8=True,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
a1_scale=layer.a13_scale,
a2_scale=layer.a2_scale,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
)
......@@ -164,9 +164,9 @@ class LogitsProcessor(nn.Module):
last_logits = last_logits[:, : self.config.vocab_size].float()
if hasattr(self.config, "final_logit_softcapping"):
last_logits /= self.config.final_logit_softcapping
last_logits.div_(self.config.final_logit_softcapping)
last_logits = torch.tanh(last_logits)
last_logits *= self.config.final_logit_softcapping
last_logits.mul_(self.config.final_logit_softcapping)
# Return only last_logits if logprob is not requested
if not logits_metadata.return_logprob:
......@@ -209,9 +209,9 @@ class LogitsProcessor(nn.Module):
all_logits = all_logits[:, : self.config.vocab_size].float()
if hasattr(self.config, "final_logit_softcapping"):
all_logits /= self.config.final_logit_softcapping
all_logits.div_(self.config.final_logit_softcapping)
all_logits = torch.tanh(all_logits)
all_logits *= self.config.final_logit_softcapping
all_logits.mul_(self.config.final_logit_softcapping)
all_logprobs = all_logits
del all_logits, hidden_states
......
......@@ -53,7 +53,7 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_available_gpu_memory,
is_generation_model,
is_llama3_405b_fp8,
is_llama3_405b_fp8_head_16,
is_multimodal_model,
monkey_patch_vllm_dummy_weight_loader,
monkey_patch_vllm_p2p_access_check,
......@@ -158,7 +158,7 @@ class ModelRunner:
skip_tokenizer_init=True,
)
if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8:
if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
self.model_config.hf_config.num_key_value_heads = 8
vllm_model_config.hf_config.num_key_value_heads = 8
......
......@@ -16,20 +16,17 @@ limitations under the License.
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Grok1 model."""
import warnings
from typing import Iterable, List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from torch import nn
from transformers import PretrainedConfig
from vllm import _custom_ops as ops
from vllm.config import CacheConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
......@@ -37,7 +34,6 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
......@@ -45,141 +41,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once
from sglang.srt.layers.fused_moe import fused_moe
from sglang.srt.layers.fused_moe import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
use_fused = True
class Grok1MLP(nn.Module):
    """Gated feed-forward block computing ``w2(GELU(w1(x)) * w3(x))``.

    All three projections are bias-free ``ReplicatedLinear`` layers.
    """

    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.num_experts = num_experts
        self.ffn_dim = intermediate_size
        self.hidden_dim = hidden_size

        def _projection(in_features: int, out_features: int) -> ReplicatedLinear:
            # Every projection shares the same bias/quantization settings.
            return ReplicatedLinear(
                in_features, out_features, bias=False, quant_config=quant_config
            )

        self.w1 = _projection(self.hidden_dim, self.ffn_dim)
        self.w2 = _projection(self.ffn_dim, self.hidden_dim)
        self.w3 = _projection(self.hidden_dim, self.ffn_dim)
        self.act_fn = nn.GELU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``hidden_states`` and return the result."""
        # Each linear layer returns a pair; only the first element is used.
        gate_branch, _ = self.w1(hidden_states)
        up_branch, _ = self.w3(hidden_states)
        gated = self.act_fn(gate_branch) * up_branch
        projected, _ = self.w2(gated)
        return projected
class Grok1MoEUnfused(nn.Module):
    """MoE block that runs each locally owned expert as a separate ``Grok1MLP``.

    The experts are split across tensor-parallel ranks. Each rank evaluates
    only its own experts and the partial outputs are combined with
    ``tensor_model_parallel_all_reduce``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the experts owned by this rank and the router gate.

        Raises:
            ValueError: If there are more ranks than experts, or this rank
                ends up with no experts.
        """
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.num_total_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.num_total_experts}."
            )

        # Split experts equally between ranks
        ids_per_rank = np.array_split(range(self.num_total_experts), self.tp_size)
        self.expert_indicies = ids_per_rank[self.rank].tolist()
        if not self.expert_indicies:
            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")

        # One slot per global expert id; slots owned by other ranks hold None.
        expert_slots = []
        for expert_id in range(self.num_total_experts):
            if expert_id in self.expert_indicies:
                expert_slots.append(
                    Grok1MLP(
                        self.num_total_experts,
                        config.hidden_size,
                        config.intermediate_size,
                        quant_config=quant_config,
                    )
                )
            else:
                expert_slots.append(None)
        self.experts = nn.ModuleList(expert_slots)

        self.gate = ReplicatedLinear(
            config.hidden_size, self.num_total_experts, bias=False, quant_config=None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to their top-k experts and sum the weighted outputs."""
        router_logits, _ = self.gate(hidden_states)
        # Soft-cap the logits with 30 * tanh(x / 30).
        router_logits = 30 * F.tanh(router_logits / 30)

        expert_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            expert_probs, self.top_k, dim=-1
        )
        routing_weights = routing_weights.to(hidden_states.dtype)

        hidden_dim = hidden_states.shape[1]
        output_shape = (hidden_states.shape[0], hidden_dim)
        final_hidden_states = torch.zeros(
            output_shape,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # Permuted so that indexing by expert id gives a (top-k slot, token) mask.
        one_hot = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_total_experts
        )
        expert_mask = one_hot.permute(2, 1, 0)

        for expert_idx in self.expert_indicies:
            slot_ids, token_ids = torch.where(expert_mask[expert_idx])
            if token_ids.shape[0] == 0:
                # No token selected this expert.
                continue
            expert_layer = self.experts[expert_idx]

            # in torch it is faster to index using lists than torch tensors
            token_id_list = token_ids.tolist()
            slot_id_list = slot_ids.tolist()

            # Gather the hidden states routed to this expert, run the expert,
            # and scale each output row by the matching routing weight.
            expert_input = hidden_states[None, token_id_list].reshape(-1, hidden_dim)
            row_weights = routing_weights[token_id_list, slot_id_list, None]
            weighted_output = expert_layer(expert_input) * row_weights

            # `index_add_` only supports torch tensors for indexing, so the
            # tensor form of the token ids is used here.
            final_hidden_states.index_add_(0, token_ids, weighted_output)

        return tensor_model_parallel_all_reduce(final_hidden_states)
class Grok1MoE(nn.Module):
"""A tensor-parallel MoE implementation for Grok1 that shards each expert
......@@ -197,221 +65,42 @@ class Grok1MoE(nn.Module):
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
):
super().__init__()
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
self.num_total_experts = num_experts
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // self.tp_size
self.quant_config = quant_config
# FIXME(pcmoritz): Make this more general to support different
# quantization schemes
self.use_fp8 = isinstance(quant_config, Fp8Config)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(
self.hidden_size,
self.num_total_experts,
hidden_size,
num_experts,
bias=False,
params_dtype=self.params_dtype,
params_dtype=params_dtype,
quant_config=None,
)
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
self.w13_weight = nn.Parameter(
torch.empty(
self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
dtype=params_dtype,
)
)
self.w2_weight = nn.Parameter(
torch.empty(
self.num_total_experts,
self.hidden_size,
self.intermediate_size,
dtype=params_dtype,
)
)
set_weight_attrs(
self.w13_weight,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2_weight,
{
"weight_loader": self.weight_loader,
},
self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
)
# Used for fp8.
self.w13_scale = None
self.w2_scale = None
self.a13_scale = None
self.a2_scale = None
if self.use_fp8:
# WEIGHT_SCALE (for fp8)
self.w13_scale = nn.Parameter(
torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
self.w2_scale = nn.Parameter(
torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(
self.w13_scale,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2_scale,
{
"weight_loader": self.weight_loader,
},
)
# ACT_SCALE (for fp8)
if quant_config.activation_scheme == "static":
if not quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
self.a13_scale = nn.Parameter(
torch.zeros(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
self.a2_scale = nn.Parameter(
torch.zeros(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
set_weight_attrs(
self.a13_scale,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.a2_scale,
{
"weight_loader": self.weight_loader,
},
)
def weight_loader(
    self,
    param: nn.Parameter,
    loaded_weight: torch.Tensor,
    weight_name: str,
    expert_id: int,
    pre_sharded: bool,
):
    """Copy one checkpoint tensor into the slot of ``param`` for one expert.

    The destination is chosen from ``weight_name``:

    * ``...w1.weight`` -> rows ``[0, shard_size)`` of ``param[expert_id]``
    * ``...w3.weight`` -> rows ``[shard_size, 2 * shard_size)``
    * ``...w2.weight`` -> all of ``param[expert_id]`` (sharded on columns)
    * names containing ``act_scale`` / ``weight_scale`` -> ``param[expert_id]``

    where ``shard_size`` is ``self.intermediate_size``.  Names matching none
    of the patterns are ignored.

    Args:
        param: Stacked per-expert parameter, written in place via ``.data``.
        loaded_weight: Tensor read from the checkpoint.
        weight_name: Checkpoint name used to select the destination.
        expert_id: Index of the expert along dimension 0 of ``param``.
        pre_sharded: If true, ``loaded_weight`` is taken as already holding
            only this rank's shard and is copied whole; otherwise the slice
            belonging to the current tensor-parallel rank is extracted.
    """
    param_data = param.data
    shard_size = self.intermediate_size

    if pre_sharded:
        # The weight is already sharded. Read the full shard.
        shard = slice(None)
    else:
        tp_rank = get_tensor_model_parallel_rank()
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

    # w1 and w3 share one fused parameter: w1 fills the first half of the
    # rows and w3 the second half.
    if weight_name.endswith("w1.weight"):
        param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
    if weight_name.endswith("w3.weight"):
        param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
            shard, :
        ]
    # w2 is sliced along its second dimension instead of its first.
    if weight_name.endswith("w2.weight"):
        param_data[expert_id, :, :] = loaded_weight[:, shard]
    # Scales are stored as one entry per expert.
    if "act_scale" in weight_name or "weight_scale" in weight_name:
        param_data[expert_id] = loaded_weight
def process_weights_after_loading(self):
    """Finalize the fp8 expert weights and scales once loading is done.

    Does nothing unless ``self.use_fp8`` is set.  Otherwise:

    * If the checkpoint was not serialized as fp8, each expert's
      ``w13_weight`` / ``w2_weight`` slice is quantized with
      ``ops.scaled_fp8_quant``; the resulting scales are written into
      ``self.w13_scale`` / ``self.w2_scale`` and the weights are replaced
      by float8 parameters.
    * If the checkpoint is fp8 with a ``"static"`` activation scheme, the
      per-expert activation scales are collapsed to their maximum.

    Raises:
        ValueError: The activation scheme is static but ``self.a13_scale``
            or ``self.a2_scale`` is ``None``.
    """
    # Fp8 is the only case where we need to process after loading.
    if not self.use_fp8:
        return

    quant_config = self.quant_config

    # If checkpoint is fp16, quantize here.
    if not quant_config.is_checkpoint_fp8_serialized:
        w13_weight = torch.empty_like(
            self.w13_weight.data, dtype=torch.float8_e4m3fn
        )
        w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
        for expert in range(self.num_total_experts):
            w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
                self.w13_weight.data[expert, :, :]
            )
            w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
                self.w2_weight.data[expert, :, :]
            )
        self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
        self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)

    # If checkpoint is fp8 + static, cleanup act_scales.
    # Since state_dict has an act_scale per expert but our kernels
    # are passed one act_scale shared across all experts.
    elif quant_config.activation_scheme == "static":
        if self.a13_scale is None or self.a2_scale is None:
            raise ValueError(
                "QuantConfig has static quantization, but found "
                "activation scales are None."
            )

        if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
            print_warning_once(
                "Found act_scales that are not equal for fp8 MoE layer. "
                "Using the maximum across experts for each layer. "
            )

        self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
        self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(
hidden_states,
self.w13_weight,
self.w2_weight,
router_logits,
self.top_k,
renormalize=False,
inplace=True,
use_fp8=self.use_fp8,
w1_scale=self.w13_scale,
w2_scale=self.w2_scale,
a1_scale=self.a13_scale,
a2_scale=self.a2_scale,
)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
router_logits = 30.0 * F.tanh(router_logits / 30.0)
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape)
class Grok1Attention(nn.Module):
......@@ -478,6 +167,7 @@ class Grok1Attention(nn.Module):
layer_id=layer_id,
logit_cap=logit_cap,
)
# TODO(lianmin): load logit cap from config
def forward(
self,
......@@ -502,7 +192,7 @@ class Grok1DecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = Grok1Attention(
hidden_size=self.hidden_size,
......@@ -513,18 +203,13 @@ class Grok1DecoderLayer(nn.Module):
rope_theta=rope_theta,
quant_config=quant_config,
)
if use_fused:
self.block_sparse_moe = Grok1MoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
)
else:
self.block_sparse_moe = Grok1MoEUnfused(
config=config, quant_config=quant_config
)
self.block_sparse_moe = Grok1MoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
)
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -536,6 +221,7 @@ class Grok1DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
# Self Attention
hidden_states = (
self.post_attn_norm(
self.self_attn(
......@@ -547,11 +233,11 @@ class Grok1DecoderLayer(nn.Module):
+ hidden_states
)
# Fully Connected
hidden_states = (
self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
+ hidden_states
)
return hidden_states
......@@ -593,7 +279,6 @@ class Grok1Model(nn.Module):
for i in range(len(self.layers)):
hidden_states = self.layers[i](positions, hidden_states, input_metadata)
hidden_states = self.norm(hidden_states)
hidden_states.mul_(self.config.output_multiplier_scale)
return hidden_states
......@@ -615,8 +300,8 @@ class Grok1ModelForCausalLM(nn.Module):
# Monkey patch _prepare_weights to load pre-sharded weights
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
warnings.filterwarnings("ignore", category=FutureWarning)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
......@@ -637,50 +322,17 @@ class Grok1ModelForCausalLM(nn.Module):
("qkv_proj", "v_proj", "v"),
]
if use_fused:
expert_params_mapping = (
[
# These are the weight scales for the experts
# (param_name, weight_name, expert_id)
(
"w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
+ [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
(
"w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
+ [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
(
"a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
)
else:
expert_params_mapping = []
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts,
)
params_dict = dict(self.named_parameters())
if get_tensor_model_parallel_rank() == 0:
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4))
for name, loaded_weight in weights:
# print(get_tensor_model_parallel_rank(), name)
if "rotary_emb.inv_freq" in name:
continue
......@@ -691,21 +343,25 @@ class Grok1ModelForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for param_name, weight_name, expert_id in expert_params_mapping:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
weight_name,
shard_id=shard_id,
expert_id=expert_id,
pre_sharded=get_tensor_model_parallel_world_size() > 1,
)
......@@ -714,6 +370,9 @@ class Grok1ModelForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
......@@ -721,11 +380,6 @@ class Grok1ModelForCausalLM(nn.Module):
weight_loader(param, loaded_weight)
def all_close_1d(x: torch.Tensor) -> bool:
    """Return whether every element of a 1-D tensor is close to the first.

    Closeness uses ``torch.allclose`` with its default tolerances, comparing
    the first element against each element.  An empty tensor yields ``True``.

    Args:
        x: Tensor that must have exactly one dimension.

    Returns:
        ``True`` if all elements are close to ``x[0]``, else ``False``.
    """
    assert len(x.shape) == 1
    if x.shape[0] == 0:
        # Nothing to compare; mirrors `all()` over an empty sequence.
        return True
    # One vectorized comparison instead of one `allclose` call per element.
    return torch.allclose(x[0].expand_as(x), x)
old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
......
......@@ -32,7 +32,6 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
......
......@@ -35,7 +35,6 @@ import torch
import torch.distributed as dist
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from starlette.middleware.base import BaseHTTPMiddleware
from torch.nn.parameter import Parameter
from triton.runtime.cache import (
FileCacheManager,
......@@ -644,7 +643,7 @@ def set_ulimit(target_soft_limit=65535):
logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
def is_llama3_405b_fp8(model_config):
def is_llama3_405b_fp8_head_16(model_config):
"""Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
if (
model_config.hf_config.architectures[0] == "LlamaForCausalLM"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment