change / sglang · Commits · 0fafc560

Unverified commit 0fafc560, authored May 21, 2024 by Lianmin Zheng, committed via GitHub on May 21, 2024.

    port fp8 mixtral (#460)

parent 19d2135c

Showing 6 changed files with 636 additions and 121 deletions (+636 / -121):
    python/sglang/srt/managers/router/model_rpc.py      +1    -8
    python/sglang/srt/managers/router/model_runner.py    +16   -12
    python/sglang/srt/models/mixtral.py                   +240  -101
    python/sglang/srt/models/mixtral_quant.py             +371  -0
    python/sglang/srt/server_args.py                      +7    -0
    python/sglang/srt/utils.py                            +1    -0
python/sglang/srt/managers/router/model_rpc.py (view file @ 0fafc560)

@@ -69,20 +69,13 @@ class ModelRpcServer:
         )

         # For model end global settings
-        server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
-
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
             nccl_port=port_args.nccl_port,
-            load_format=server_args.load_format,
-            trust_remote_code=server_args.trust_remote_code,
-            server_args_dict=server_args_dict,
+            server_args=server_args,
         )
         if is_multimodal_model(server_args.model_path):
             self.processor = get_processor(
python/sglang/srt/managers/router/model_runner.py (view file @ 0fafc560)

@@ -17,6 +17,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
@@ -218,22 +219,23 @@ class ModelRunner:
         tp_rank,
         tp_size,
         nccl_port,
-        load_format="auto",
-        trust_remote_code=True,
-        server_args_dict: dict = {},
+        server_args: ServerArgs,
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.nccl_port = nccl_port
-        self.load_format = load_format
-        self.trust_remote_code = trust_remote_code
+        self.server_args = server_args

         global global_server_args_dict
-        global_server_args_dict = server_args_dict
+        global_server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }

         # Init torch distributed
+        logger.debug("Init torch begin.")
         torch.cuda.set_device(self.tp_rank)
         torch.distributed.init_process_group(
             backend="nccl",
@@ -241,13 +243,15 @@ class ModelRunner:
             rank=self.tp_rank,
             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        logger.debug("Init torch end.")

         total_gpu_memory = get_available_gpu_memory(
             self.tp_rank, distributed=self.tp_size > 1
         ) * (1 << 30)

+        # logger.info(f"Before: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
         self.load_model()
+        # logger.info(f"After: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
         self.init_memory_pool(total_gpu_memory)

         self.is_multimodal_model = is_multimodal_model(self.model_config)
@@ -256,15 +260,15 @@ class ModelRunner:
         logger.info(f"Rank {self.tp_rank}: load weight begin.")

         device_config = DeviceConfig()
-        load_config = LoadConfig()
+        load_config = LoadConfig(load_format=self.server_args.load_format)
         vllm_model_config = VllmModelConfig(
-            model=self.model_config.path,
+            model=self.server_args.model_path,
+            quantization=self.server_args.quantization,
             tokenizer=None,
             tokenizer_mode=None,
-            trust_remote_code=self.model_config.trust_remote_code,
+            trust_remote_code=self.server_args.trust_remote_code,
             dtype=torch.float16,
             seed=42,
-            revision=self.model_config.revision,
             skip_tokenizer_init=True,
         )
         if self.model_config.model_overide_args is not None:

@@ -279,7 +283,7 @@ class ModelRunner:
             parallel_config=None,
             scheduler_config=None,
         )
-        logger.info(f"Rank {self.tp_rank}: load weight end.")
+        logger.info(f"Rank {self.tp_rank}: load weight end. {type(self.model)}")

     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
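Net effect of the two files above: ModelRunner now receives the whole ServerArgs object and derives the global settings dict itself, instead of being handed load_format / trust_remote_code / a loose server_args_dict. A minimal sketch of that flow (the field values are hypothetical; only the field names shown in the diff are taken from it):

    from sglang.srt.server_args import ServerArgs

    # Hypothetical arguments, for illustration only.
    args = ServerArgs(
        model_path="mistralai/Mixtral-8x7B-Instruct-v0.1",
        tp_size=2,
        quantization="fp8",  # new field added in this commit
    )

    # This is what ModelRunner.__init__ now builds internally from ServerArgs.
    global_server_args_dict = {
        "enable_flashinfer": args.enable_flashinfer,
        "attention_reduce_in_fp32": args.attention_reduce_in_fp32,
    }
    print(global_server_args_dict)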
python/sglang/srt/models/mixtral.py (view file @ 0fafc560)

 # Adapted from
-# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""
 from typing import Iterable, Optional, Tuple

@@ -8,11 +8,13 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
+from vllm import _custom_ops as ops
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,

@@ -20,12 +22,15 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import print_warning_once

 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -33,106 +38,196 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata

(This hunk removes the old MixtralMLP class and the old per-rank MixtralMoE, which split whole experts across ranks with np.array_split and ran them as ReplicatedLinear MLPs; both are preserved verbatim in the new file python/sglang/srt/models/mixtral_quant.py below. In their place it adds the fused, fp8-capable MixtralMoE shown here.)

class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size // self.tp_size
        self.quant_config = quant_config

        # FIXME(pcmoritz): Make this more general to support different
        # quantization schemes
        self.use_fp8 = isinstance(quant_config, Fp8Config)

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            self.hidden_size,
            self.num_total_experts,
            bias=False,
            params_dtype=self.params_dtype,
            quant_config=None,
        )

        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        self.w13_weight = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                2 * self.intermediate_size,
                self.hidden_size,
                dtype=params_dtype,
            )
        )
        self.w2_weight = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                self.hidden_size,
                self.intermediate_size,
                dtype=params_dtype,
            )
        )

        set_weight_attrs(self.w13_weight, {"weight_loader": self.weight_loader})
        set_weight_attrs(self.w2_weight, {"weight_loader": self.weight_loader})

        # Used for fp8.
        self.w13_scale = None
        self.w2_scale = None
        self.a13_scale = None
        self.a2_scale = None

        if self.use_fp8:
            # WEIGHT_SCALE (for fp8)
            self.w13_scale = nn.Parameter(
                torch.ones(self.num_total_experts, dtype=torch.float32),
                requires_grad=False,
            )
            self.w2_scale = nn.Parameter(
                torch.ones(self.num_total_experts, dtype=torch.float32),
                requires_grad=False,
            )

            # If loading fp8 checkpoint, pass the weight loaders.
            # If loading an fp16 checkpoint, do not (we will quantize in
            #   process_weights_after_loading()
            if quant_config.is_checkpoint_fp8_serialized:
                set_weight_attrs(self.w13_scale, {"weight_loader": self.weight_loader})
                set_weight_attrs(self.w2_scale, {"weight_loader": self.weight_loader})

            # ACT_SCALE (for fp8)
            if quant_config.activation_scheme == "static":
                if not quant_config.is_checkpoint_fp8_serialized:
                    raise ValueError(
                        "Found static activation scheme for checkpoint that "
                        "was not serialized fp8."
                    )
                self.a13_scale = nn.Parameter(
                    torch.zeros(self.num_total_experts, dtype=torch.float32),
                    requires_grad=False,
                )
                self.a2_scale = nn.Parameter(
                    torch.zeros(self.num_total_experts, dtype=torch.float32),
                    requires_grad=False,
                )

                set_weight_attrs(self.a13_scale, {"weight_loader": self.weight_loader})
                set_weight_attrs(self.a2_scale, {"weight_loader": self.weight_loader})

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        expert_id: int,
    ):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w3.weight"):
            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w2.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]
        if "act_scale" in weight_name or "weight_scale" in weight_name:
            param_data[expert_id] = loaded_weight

    def process_weights_after_loading(self):
        # Fp8 is the only case where we need to process after loading.
        if not self.use_fp8:
            return

        # If checkpoint is fp16, quantize here.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            w13_weight = torch.empty_like(self.w13_weight.data, dtype=torch.float8_e4m3fn)
            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
            for expert in range(self.num_total_experts):
                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
                    self.w13_weight.data[expert, :, :]
                )
                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
                    self.w2_weight.data[expert, :, :]
                )
            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)

        # If checkpoint is fp8 + static, cleanup act_scales.
        #   Since state_dict has an act_scale per expert but our kernels
        #   are passed one act_scale shared across all experts.
        elif self.quant_config.activation_scheme == "static":
            if self.a13_scale is None or self.a2_scale is None:
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None."
                )

            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
                print_warning_once(
                    "Found act_scales that are not equal for fp8 MoE layer. "
                    "Using the maximum across experts for each layer. "
                )

            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(
            hidden_states,
            self.w13_weight,
            self.w2_weight,
            router_logits,
            self.top_k,
            renormalize=True,
            inplace=True,
            use_fp8=self.use_fp8,
            w1_scale=self.w13_scale,
            w2_scale=self.w2_scale,
            a1_scale=self.a13_scale,
            a2_scale=self.a2_scale,
        )

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_size)

(The hunk ends just before the unchanged class MixtralAttention(nn.Module).)
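For fp16 checkpoints, process_weights_after_loading() above calls ops.scaled_fp8_quant once per expert to obtain a float8_e4m3fn weight tensor plus a per-tensor scale. A rough pure-PyTorch sketch of that per-tensor dynamic quantization, written as a reference rather than the actual vLLM kernel (the shape is made up):

    import torch

    def scaled_fp8_quant_reference(w: torch.Tensor):
        # Per-tensor scaling: scale = amax / fp8_max, q = w / scale cast to fp8.
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        scale = w.abs().max().float().clamp(min=1e-12) / fp8_max
        q = (w.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        return q, scale

    w = torch.randn(512, 128, dtype=torch.float16)  # stand-in for one expert's weight slice
    q, scale = scaled_fp8_quant_reference(w)
    print(q.dtype, scale.item())                    # torch.float8_e4m3fn, amax / fp8 max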
@@ -234,7 +329,12 @@ class MixtralDecoderLayer(nn.Module):
             sliding_window=config.sliding_window,
             quant_config=quant_config,
         )
-        self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
+        self.block_sparse_moe = MixtralMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config,
+        )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -342,11 +442,35 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]

+        expert_params_mapping = [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id)
+            ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ] + [
+            # These are the weights for the experts
+            # (param_name, weight_name, expert_id)
+            ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ] + [
+            # These are the activation scales for the experts
+            # (param_name, weight_name, expert_id)
+            ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
+             f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)

@@ -358,15 +482,30 @@ class MixtralForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if "block_sparse_moe.experts." in name and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for param_name, weight_name, expert_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                    weight_loader(param, loaded_weight)
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))


 EntryClass = MixtralForCausalLM
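The expert_params_mapping above exists to fold the per-expert checkpoint tensors (experts.{i}.w1/w2/w3 plus their fp8 weight_scale and act_scale entries) into the stacked w13_weight / w2_weight parameters and their scale vectors. A tiny standalone sketch of just the renaming step (the checkpoint key is hypothetical):

    num_local_experts = 8  # Mixtral 8x7B has 8 local experts

    expert_params_mapping = [
        ("w13_weight" if w in ["w1", "w3"] else "w2_weight", f"experts.{e}.{w}.weight", e)
        for e in range(num_local_experts)
        for w in ["w1", "w2", "w3"]
    ]

    ckpt_name = "model.layers.0.block_sparse_moe.experts.3.w1.weight"  # hypothetical key
    for param_name, weight_name, expert_id in expert_params_mapping:
        if weight_name in ckpt_name:
            # -> model.layers.0.block_sparse_moe.w13_weight, routed to expert slot 3
            print(ckpt_name.replace(weight_name, param_name), "-> expert", expert_id)
            break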
python/sglang/srt/models/mixtral_quant.py (new file, 0 → 100644, view file @ 0fafc560)
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1
"""Inference-only Mixtral model."""
from typing import Iterable, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
class MixtralMLP(nn.Module):
    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.num_experts = num_experts
        self.ffn_dim = intermediate_size
        self.hidden_dim = hidden_size

        self.w1 = ReplicatedLinear(
            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
        )
        self.w2 = ReplicatedLinear(
            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
        )
        self.w3 = ReplicatedLinear(
            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
        )

        # TODO: Use vllm's SiluAndMul
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        w1_out, _ = self.w1(hidden_states)
        w1_out = self.act_fn(w1_out)
        w3_out, _ = self.w3(hidden_states)
        current_hidden_states = w1_out * w3_out
        current_hidden_states, _ = self.w2(current_hidden_states)
        return current_hidden_states
class MixtralMoE(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.num_total_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.num_total_experts}."
            )
        # Split experts equally between ranks
        self.expert_indicies = np.array_split(
            range(self.num_total_experts), self.tp_size
        )[self.rank].tolist()
        if not self.expert_indicies:
            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")

        self.experts = nn.ModuleList(
            [
                (
                    MixtralMLP(
                        self.num_total_experts,
                        config.hidden_size,
                        config.intermediate_size,
                        quant_config=quant_config,
                    )
                    if idx in self.expert_indicies
                    else None
                )
                for idx in range(self.num_total_experts)
            ]
        )
        self.gate = ReplicatedLinear(
            config.hidden_size, self.num_total_experts, bias=False, quant_config=None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits, _ = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        final_hidden_states = None
        for expert_idx in self.expert_indicies:
            expert_layer = self.experts[expert_idx]
            expert_mask = selected_experts == expert_idx
            expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)

            current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
            if final_hidden_states is None:
                final_hidden_states = current_hidden_states
            else:
                final_hidden_states.add_(current_hidden_states)

        return tensor_model_parallel_all_reduce(final_hidden_states)
class MixtralAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        quant_config: Optional[QuantizationConfig] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output
class MixtralDecoderLayer(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            quant_config=quant_config,
        )
        self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
class MixtralModel(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                MixtralDecoderLayer(config, i, quant_config=quant_config)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions, hidden_states, input_metadata, residual
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class QuantMixtralForCausalLM(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = MixtralModel(config, quant_config=quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if "block_sparse_moe.experts." in name and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = QuantMixtralForCausalLM
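The forward pass of this fallback MixtralMoE is easy to check in isolation: softmax over the router logits, top-k per token, renormalize, then weight each locally-owned expert's output by its (possibly zero) routing weight. A small self-contained example of just those lines, with made-up shapes:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    num_tokens, num_experts, top_k = 4, 8, 2
    router_logits = torch.randn(num_tokens, num_experts)

    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

    # Weight that expert 0 contributes to each token (zero where it was not selected).
    expert_mask = selected_experts == 0
    expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
    print(expert_weights.squeeze(-1))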
python/sglang/srt/server_args.py (view file @ 0fafc560)

@@ -15,6 +15,7 @@ class ServerArgs:
     chat_template: Optional[str] = None
     trust_remote_code: bool = True
     context_length: Optional[int] = None
+    quantization: Optional[str] = None

     # Port
     host: str = "127.0.0.1"

@@ -135,6 +136,12 @@ class ServerArgs:
             default=ServerArgs.context_length,
             help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
         )
+        parser.add_argument(
+            "--quantization",
+            type=str,
+            default=ServerArgs.quantization,
+            help="The quantization method.",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
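The new --quantization flag is how fp8 Mixtral gets selected at launch. A hypothetical parse, assuming ServerArgs keeps the usual add_cli_args / from_cli_args helpers (those helper names and the model path are assumptions; only the --quantization flag itself comes from this diff):

    import argparse
    from sglang.srt.server_args import ServerArgs

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)               # assumed helper that registers the arguments above
    args = parser.parse_args(
        ["--model-path", "mistralai/Mixtral-8x7B-Instruct-v0.1", "--quantization", "fp8"]
    )
    server_args = ServerArgs.from_cli_args(args)  # assumed companion constructor
    print(server_args.quantization)               # "fp8", later forwarded to VllmModelConfig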
python/sglang/srt/utils.py (view file @ 0fafc560)

@@ -106,6 +106,7 @@ def get_available_gpu_memory(gpu_id, distributed=True):
             "which may cause useless memory allocation for torch CUDA context.",
         )

+    torch.cuda.empty_cache()
     free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

     if distributed:
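The one-line utils.py change matters because torch.cuda.mem_get_info reports driver-level free memory, which does not include blocks parked in PyTorch's caching allocator; emptying the cache first keeps the static memory-pool sizing from being overly pessimistic. A small illustration (numbers are machine dependent; requires a CUDA device):

    import torch

    if torch.cuda.is_available():
        x = torch.empty(256 * 1024 * 1024, device="cuda")  # ~1 GB of float32
        del x                                 # returned to the caching allocator, not the driver
        before, _ = torch.cuda.mem_get_info(0)
        torch.cuda.empty_cache()              # the call this commit adds before measuring
        after, _ = torch.cuda.mem_get_info(0)
        print(f"free before: {before / 1e9:.2f} GB, after: {after / 1e9:.2f} GB")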