Unverified Commit 0fafc560 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

port fp8 mixtral (#460)

parent 19d2135c
...@@ -69,20 +69,13 @@ class ModelRpcServer: ...@@ -69,20 +69,13 @@ class ModelRpcServer:
) )
# For model end global settings # For model end global settings
server_args_dict = {
"enable_flashinfer": server_args.enable_flashinfer,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
}
self.model_runner = ModelRunner( self.model_runner = ModelRunner(
model_config=self.model_config, model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static, mem_fraction_static=server_args.mem_fraction_static,
tp_rank=tp_rank, tp_rank=tp_rank,
tp_size=server_args.tp_size, tp_size=server_args.tp_size,
nccl_port=port_args.nccl_port, nccl_port=port_args.nccl_port,
load_format=server_args.load_format, server_args=server_args,
trust_remote_code=server_args.trust_remote_code,
server_args_dict=server_args_dict,
) )
if is_multimodal_model(server_args.model_path): if is_multimodal_model(server_args.model_path):
self.processor = get_processor( self.processor = get_processor(
......
...@@ -17,6 +17,7 @@ from vllm.model_executor.models import ModelRegistry ...@@ -17,6 +17,7 @@ from vllm.model_executor.models import ModelRegistry
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
...@@ -218,22 +219,23 @@ class ModelRunner: ...@@ -218,22 +219,23 @@ class ModelRunner:
tp_rank, tp_rank,
tp_size, tp_size,
nccl_port, nccl_port,
load_format="auto", server_args: ServerArgs,
trust_remote_code=True,
server_args_dict: dict = {},
): ):
self.model_config = model_config self.model_config = model_config
self.mem_fraction_static = mem_fraction_static self.mem_fraction_static = mem_fraction_static
self.tp_rank = tp_rank self.tp_rank = tp_rank
self.tp_size = tp_size self.tp_size = tp_size
self.nccl_port = nccl_port self.nccl_port = nccl_port
self.load_format = load_format self.server_args = server_args
self.trust_remote_code = trust_remote_code
global global_server_args_dict global global_server_args_dict
global_server_args_dict = server_args_dict global_server_args_dict = {
"enable_flashinfer": server_args.enable_flashinfer,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
}
# Init torch distributed # Init torch distributed
logger.debug("Init torch begin.")
torch.cuda.set_device(self.tp_rank) torch.cuda.set_device(self.tp_rank)
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend="nccl", backend="nccl",
...@@ -241,13 +243,15 @@ class ModelRunner: ...@@ -241,13 +243,15 @@ class ModelRunner:
rank=self.tp_rank, rank=self.tp_rank,
init_method=f"tcp://127.0.0.1:{self.nccl_port}", init_method=f"tcp://127.0.0.1:{self.nccl_port}",
) )
initialize_model_parallel(tensor_model_parallel_size=self.tp_size) initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
logger.debug("Init torch end.")
total_gpu_memory = get_available_gpu_memory( total_gpu_memory = get_available_gpu_memory(
self.tp_rank, distributed=self.tp_size > 1 self.tp_rank, distributed=self.tp_size > 1
) * (1 << 30) ) * (1 << 30)
# logger.info(f"Before: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
self.load_model() self.load_model()
# logger.info(f"After: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
self.init_memory_pool(total_gpu_memory) self.init_memory_pool(total_gpu_memory)
self.is_multimodal_model = is_multimodal_model(self.model_config) self.is_multimodal_model = is_multimodal_model(self.model_config)
...@@ -256,15 +260,15 @@ class ModelRunner: ...@@ -256,15 +260,15 @@ class ModelRunner:
logger.info(f"Rank {self.tp_rank}: load weight begin.") logger.info(f"Rank {self.tp_rank}: load weight begin.")
device_config = DeviceConfig() device_config = DeviceConfig()
load_config = LoadConfig() load_config = LoadConfig(load_format=self.server_args.load_format)
vllm_model_config = VllmModelConfig( vllm_model_config = VllmModelConfig(
model=self.model_config.path, model=self.server_args.model_path,
quantization=self.server_args.quantization,
tokenizer=None, tokenizer=None,
tokenizer_mode=None, tokenizer_mode=None,
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.server_args.trust_remote_code,
dtype=torch.float16, dtype=torch.float16,
seed=42, seed=42,
revision=self.model_config.revision,
skip_tokenizer_init=True, skip_tokenizer_init=True,
) )
if self.model_config.model_overide_args is not None: if self.model_config.model_overide_args is not None:
...@@ -279,7 +283,7 @@ class ModelRunner: ...@@ -279,7 +283,7 @@ class ModelRunner:
parallel_config=None, parallel_config=None,
scheduler_config=None, scheduler_config=None,
) )
logger.info(f"Rank {self.tp_rank}: load weight end.") logger.info(f"Rank {self.tp_rank}: load weight end. {type(self.model)}")
def profile_max_num_token(self, total_gpu_memory): def profile_max_num_token(self, total_gpu_memory):
available_gpu_memory = get_available_gpu_memory( available_gpu_memory = get_available_gpu_memory(
......
# Adapted from # Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import Iterable, Optional, Tuple from typing import Iterable, Optional, Tuple
...@@ -8,11 +8,13 @@ import torch ...@@ -8,11 +8,13 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import MixtralConfig from transformers import MixtralConfig
from vllm import _custom_ops as ops
from vllm.distributed import ( from vllm.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
QKVParallelLinear, QKVParallelLinear,
...@@ -20,12 +22,15 @@ from vllm.model_executor.layers.linear import ( ...@@ -20,12 +22,15 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
...@@ -33,106 +38,196 @@ from sglang.srt.layers.radix_attention import RadixAttention ...@@ -33,106 +38,196 @@ from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata from sglang.srt.managers.router.model_runner import InputMetadata
class MixtralMLP(nn.Module):
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.num_experts = num_experts
self.ffn_dim = intermediate_size
self.hidden_dim = hidden_size
self.w1 = ReplicatedLinear(
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
)
self.w2 = ReplicatedLinear(
self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
)
self.w3 = ReplicatedLinear(
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
)
# TODO: Use vllm's SiluAndMul
self.act_fn = nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class MixtralMoE(nn.Module):
w1_out, _ = self.w1(hidden_states) """A tensor-parallel MoE implementation for Mixtral that shards each expert
w1_out = self.act_fn(w1_out) across all ranks.
w3_out, _ = self.w3(hidden_states)
current_hidden_states = w1_out * w3_out
current_hidden_states, _ = self.w2(current_hidden_states)
return current_hidden_states
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
class MixtralMoE(nn.Module):
def __init__( def __init__(
self, self,
config: MixtralConfig, num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.tp_size = tp_size or get_tensor_model_parallel_world_size()
self.rank = get_tensor_model_parallel_rank() self.num_total_experts = num_experts
self.tp_size = get_tensor_model_parallel_world_size() self.top_k = top_k
self.num_total_experts = config.num_local_experts self.hidden_size = hidden_size
self.top_k = config.num_experts_per_tok self.intermediate_size = intermediate_size // self.tp_size
if self.tp_size > self.num_total_experts: self.quant_config = quant_config
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than " # FIXME(pcmoritz): Make this more general to support different
f"the number of experts {self.num_total_experts}." # quantization schemes
) self.use_fp8 = isinstance(quant_config, Fp8Config)
# Split experts equally between ranks
self.expert_indicies = np.array_split( if params_dtype is None:
range(self.num_total_experts), self.tp_size params_dtype = torch.get_default_dtype()
)[self.rank].tolist() self.params_dtype = params_dtype
if not self.expert_indicies:
raise ValueError(f"Rank {self.rank} has no experts assigned to it.") # Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(self.hidden_size,
self.experts = nn.ModuleList( self.num_total_experts,
[ bias=False,
( params_dtype=self.params_dtype,
MixtralMLP( quant_config=None)
self.num_total_experts,
config.hidden_size, if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
config.intermediate_size, params_dtype = torch.float8_e4m3fn
quant_config=quant_config,
) self.w13_weight = nn.Parameter(
if idx in self.expert_indicies torch.empty(self.num_total_experts,
else None 2 * self.intermediate_size,
) self.hidden_size,
for idx in range(self.num_total_experts) dtype=params_dtype))
] self.w2_weight = nn.Parameter(
) torch.empty(self.num_total_experts,
self.gate = ReplicatedLinear( self.hidden_size,
config.hidden_size, self.num_total_experts, bias=False, quant_config=None self.intermediate_size,
) dtype=params_dtype))
set_weight_attrs(self.w13_weight, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2_weight, {
"weight_loader": self.weight_loader,
})
# Used for fp8.
self.w13_scale = None
self.w2_scale = None
self.a13_scale = None
self.a2_scale = None
if self.use_fp8:
# WEIGHT_SCALE (for fp8)
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(self.w13_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2_scale, {
"weight_loader": self.weight_loader,
})
# ACT_SCALE (for fp8)
if quant_config.activation_scheme == "static":
if not quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
self.a13_scale = nn.Parameter(torch.zeros(
self.num_total_experts, dtype=torch.float32),
requires_grad=False)
self.a2_scale = nn.Parameter(torch.zeros(
self.num_total_experts, dtype=torch.float32),
requires_grad=False)
set_weight_attrs(self.a13_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.a2_scale, {
"weight_loader": self.weight_loader,
})
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str, expert_id: int):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
if weight_name.endswith("w1.weight"):
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w3.weight"):
param_data[expert_id,
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
if "act_scale" in weight_name or "weight_scale" in weight_name:
param_data[expert_id] = loaded_weight
def process_weights_after_loading(self):
# Fp8 is the only case where we need to process after loading.
if not self.use_fp8:
return
# If checkpoint is fp16, quantize here.
if not self.quant_config.is_checkpoint_fp8_serialized:
w13_weight = torch.empty_like(self.w13_weight.data,
dtype=torch.float8_e4m3fn)
w2_weight = torch.empty_like(self.w2_weight.data,
dtype=torch.float8_e4m3fn)
for expert in range(self.num_total_experts):
w13_weight[expert, :, :], self.w13_scale[
expert] = ops.scaled_fp8_quant(
self.w13_weight.data[expert, :, :])
w2_weight[expert, :, :], self.w2_scale[
expert] = ops.scaled_fp8_quant(
self.w2_weight.data[expert, :, :])
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
# If checkpoint is fp8 + static, cleanup act_scales.
# Since state_dict has an act_scale per expert but our kernels
# are passed one act_scale shared across all experts.
elif self.quant_config.activation_scheme == "static":
if self.a13_scale is None or self.a2_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
if (not all_close_1d(self.a13_scale)
or not all_close_1d(self.a2_scale)):
print_warning_once(
"Found act_scales that are not equal for fp8 MoE layer. "
"Using the maximum across experts for each layer. ")
self.a13_scale = nn.Parameter(self.a13_scale.max(),
requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(),
requires_grad=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) self.w13_weight,
routing_weights, selected_experts = torch.topk( self.w2_weight,
routing_weights, self.top_k, dim=-1 router_logits,
) self.top_k,
routing_weights /= routing_weights.sum(dim=-1, keepdim=True) renormalize=True,
inplace=True,
final_hidden_states = None use_fp8=self.use_fp8,
for expert_idx in self.expert_indicies: w1_scale=self.w13_scale,
expert_layer = self.experts[expert_idx] w2_scale=self.w2_scale,
expert_mask = selected_experts == expert_idx a1_scale=self.a13_scale,
expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True) a2_scale=self.a2_scale)
current_hidden_states = expert_layer(hidden_states).mul_(expert_weights) if self.tp_size > 1:
if final_hidden_states is None: final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states = current_hidden_states final_hidden_states)
else:
final_hidden_states.add_(current_hidden_states) return final_hidden_states.view(num_tokens, hidden_size)
return tensor_model_parallel_all_reduce(final_hidden_states)
class MixtralAttention(nn.Module): class MixtralAttention(nn.Module):
...@@ -234,7 +329,12 @@ class MixtralDecoderLayer(nn.Module): ...@@ -234,7 +329,12 @@ class MixtralDecoderLayer(nn.Module):
sliding_window=config.sliding_window, sliding_window=config.sliding_window,
quant_config=quant_config, quant_config=quant_config,
) )
self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config) self.block_sparse_moe = MixtralMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm( self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps config.hidden_size, eps=config.rms_norm_eps
...@@ -342,11 +442,35 @@ class MixtralForCausalLM(nn.Module): ...@@ -342,11 +442,35 @@ class MixtralForCausalLM(nn.Module):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
expert_params_mapping = [
# These are the weight scales for the experts
# (param_name, weight_name, expert_id)
("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
for param_name, weight_name, shard_id in stacked_params_mapping:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
...@@ -358,15 +482,30 @@ class MixtralForCausalLM(nn.Module): ...@@ -358,15 +482,30 @@ class MixtralForCausalLM(nn.Module):
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Skip loading extra bias for GPTQ models. for param_name, weight_name, expert_id in expert_params_mapping:
if name.endswith(".bias") and name not in params_dict: if weight_name not in name:
continue continue
# Skip experts that are not assigned to this worker. name = name.replace(weight_name, param_name)
if "block_sparse_moe.experts." in name and name not in params_dict: param = params_dict[name]
continue weight_loader = param.weight_loader
param = params_dict[name] weight_loader(param,
weight_loader = getattr(param, "weight_loader", default_weight_loader) loaded_weight,
weight_loader(param, loaded_weight) weight_name,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
EntryClass = MixtralForCausalLM EntryClass = MixtralForCausalLM
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1
"""Inference-only Mixtral model."""
from typing import Iterable, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
class MixtralMLP(nn.Module):
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.num_experts = num_experts
self.ffn_dim = intermediate_size
self.hidden_dim = hidden_size
self.w1 = ReplicatedLinear(
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
)
self.w2 = ReplicatedLinear(
self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
)
self.w3 = ReplicatedLinear(
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
)
# TODO: Use vllm's SiluAndMul
self.act_fn = nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
w1_out, _ = self.w1(hidden_states)
w1_out = self.act_fn(w1_out)
w3_out, _ = self.w3(hidden_states)
current_hidden_states = w1_out * w3_out
current_hidden_states, _ = self.w2(current_hidden_states)
return current_hidden_states
class MixtralMoE(nn.Module):
def __init__(
self,
config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.num_total_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.num_total_experts}."
)
# Split experts equally between ranks
self.expert_indicies = np.array_split(
range(self.num_total_experts), self.tp_size
)[self.rank].tolist()
if not self.expert_indicies:
raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
self.experts = nn.ModuleList(
[
(
MixtralMLP(
self.num_total_experts,
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
)
if idx in self.expert_indicies
else None
)
for idx in range(self.num_total_experts)
]
)
self.gate = ReplicatedLinear(
config.hidden_size, self.num_total_experts, bias=False, quant_config=None
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
router_logits, _ = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(
routing_weights, self.top_k, dim=-1
)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
final_hidden_states = None
for expert_idx in self.expert_indicies:
expert_layer = self.experts[expert_idx]
expert_mask = selected_experts == expert_idx
expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
if final_hidden_states is None:
final_hidden_states = current_hidden_states
else:
final_hidden_states.add_(current_hidden_states)
return tensor_model_parallel_all_reduce(final_hidden_states)
class MixtralAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
layer_id: int = 0,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
output, _ = self.o_proj(attn_output)
return output
class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MixtralConfig,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MixtralAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
layer_id=layer_id,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
quant_config=quant_config,
)
self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.block_sparse_moe(hidden_states)
return hidden_states, residual
class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList(
[
MixtralDecoderLayer(config, i, quant_config=quant_config)
for i in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class QuantMixtralForCausalLM(nn.Module):
def __init__(
self,
config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = MixtralModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if "block_sparse_moe.experts." in name and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = QuantMixtralForCausalLM
...@@ -15,6 +15,7 @@ class ServerArgs: ...@@ -15,6 +15,7 @@ class ServerArgs:
chat_template: Optional[str] = None chat_template: Optional[str] = None
trust_remote_code: bool = True trust_remote_code: bool = True
context_length: Optional[int] = None context_length: Optional[int] = None
quantization: Optional[str] = None
# Port # Port
host: str = "127.0.0.1" host: str = "127.0.0.1"
...@@ -135,6 +136,12 @@ class ServerArgs: ...@@ -135,6 +136,12 @@ class ServerArgs:
default=ServerArgs.context_length, default=ServerArgs.context_length,
help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).", help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
) )
parser.add_argument(
"--quantization",
type=str,
default=ServerArgs.quantization,
help="The quantization method.",
)
parser.add_argument( parser.add_argument(
"--mem-fraction-static", "--mem-fraction-static",
type=float, type=float,
......
...@@ -106,6 +106,7 @@ def get_available_gpu_memory(gpu_id, distributed=True): ...@@ -106,6 +106,7 @@ def get_available_gpu_memory(gpu_id, distributed=True):
"which may cause useless memory allocation for torch CUDA context.", "which may cause useless memory allocation for torch CUDA context.",
) )
torch.cuda.empty_cache()
free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id) free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
if distributed: if distributed:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment