Commit 0fafc560 (unverified)
Authored May 21, 2024 by Lianmin Zheng; committed by GitHub on May 21, 2024

port fp8 mixtral (#460)

Parent: 19d2135c

Showing 6 changed files with 636 additions and 121 deletions (+636 / -121)
Changed files:
  python/sglang/srt/managers/router/model_rpc.py     +1    -8
  python/sglang/srt/managers/router/model_runner.py  +16   -12
  python/sglang/srt/models/mixtral.py                +240  -101
  python/sglang/srt/models/mixtral_quant.py          +371  -0
  python/sglang/srt/server_args.py                   +7    -0
  python/sglang/srt/utils.py                         +1    -0
python/sglang/srt/managers/router/model_rpc.py

@@ -69,20 +69,13 @@ class ModelRpcServer:
         )

         # For model end global settings
-        server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
-
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
             nccl_port=port_args.nccl_port,
-            load_format=server_args.load_format,
-            trust_remote_code=server_args.trust_remote_code,
-            server_args_dict=server_args_dict,
+            server_args=server_args,
         )
         if is_multimodal_model(server_args.model_path):
             self.processor = get_processor(
python/sglang/srt/managers/router/model_runner.py

@@ -17,6 +17,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model

@@ -218,22 +219,23 @@ class ModelRunner:
         tp_rank,
         tp_size,
         nccl_port,
-        load_format="auto",
-        trust_remote_code=True,
-        server_args_dict: dict = {},
+        server_args: ServerArgs,
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.nccl_port = nccl_port
-        self.load_format = load_format
-        self.trust_remote_code = trust_remote_code
+        self.server_args = server_args

         global global_server_args_dict
-        global_server_args_dict = server_args_dict
+        global_server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }

         # Init torch distributed
+        logger.debug("Init torch begin.")
         torch.cuda.set_device(self.tp_rank)
         torch.distributed.init_process_group(
             backend="nccl",

@@ -241,13 +243,15 @@ class ModelRunner:
             rank=self.tp_rank,
             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        logger.debug("Init torch end.")

         total_gpu_memory = get_available_gpu_memory(
             self.tp_rank, distributed=self.tp_size > 1
         ) * (1 << 30)
+        # logger.info(f"Before: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
         self.load_model()
+        # logger.info(f"After: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
         self.init_memory_pool(total_gpu_memory)
         self.is_multimodal_model = is_multimodal_model(self.model_config)

@@ -256,15 +260,15 @@ class ModelRunner:
         logger.info(f"Rank {self.tp_rank}: load weight begin.")

         device_config = DeviceConfig()
-        load_config = LoadConfig()
+        load_config = LoadConfig(load_format=self.server_args.load_format)
         vllm_model_config = VllmModelConfig(
-            model=self.model_config.path,
+            model=self.server_args.model_path,
+            quantization=self.server_args.quantization,
             tokenizer=None,
             tokenizer_mode=None,
-            trust_remote_code=self.model_config.trust_remote_code,
+            trust_remote_code=self.server_args.trust_remote_code,
             dtype=torch.float16,
             seed=42,
-            revision=self.model_config.revision,
             skip_tokenizer_init=True,
         )
         if self.model_config.model_overide_args is not None:

@@ -279,7 +283,7 @@ class ModelRunner:
             parallel_config=None,
             scheduler_config=None,
         )
-        logger.info(f"Rank {self.tp_rank}: load weight end.")
+        logger.info(f"Rank {self.tp_rank}: load weight end. {type(self.model)}")

     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
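With this change the router hands the whole ServerArgs object to ModelRunner, which then derives the two model-level globals from it instead of receiving a pre-built dict. A minimal sketch of how downstream code could read those globals, assuming consumers import the module-level dict (the consumers themselves are not part of this diff; the helper name below is ours):

import importlib

# Hedged sketch: assumes other modules read the module-level dict that
# ModelRunner.__init__ populates above; nothing in this commit shows the readers.
model_runner = importlib.import_module("sglang.srt.managers.router.model_runner")

def attention_backend_flags():
    d = getattr(model_runner, "global_server_args_dict", {}) or {}
    return d.get("enable_flashinfer", False), d.get("attention_reduce_in_fp32", False)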
python/sglang/srt/models/mixtral.py

 # Adapted from
-# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""
 from typing import Iterable, Optional, Tuple

@@ -8,11 +8,13 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
+from vllm import _custom_ops as ops
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,

@@ -20,12 +22,15 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import print_warning_once

 from sglang.srt.layers.logits_processor import LogitsProcessor

@@ -33,106 +38,196 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata

(This hunk deletes the old MixtralMLP class and the old per-expert MixtralMoE class; that code is kept verbatim in the new file python/sglang/srt/models/mixtral_quant.py shown below. The fused, fp8-capable MixtralMoE added in its place is:)

+class MixtralMoE(nn.Module):
+    """A tensor-parallel MoE implementation for Mixtral that shards each expert
+    across all ranks.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
+        self.num_total_experts = num_experts
+        self.top_k = top_k
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size // self.tp_size
+        self.quant_config = quant_config
+
+        # FIXME(pcmoritz): Make this more general to support different
+        # quantization schemes
+        self.use_fp8 = isinstance(quant_config, Fp8Config)
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        # Gate always runs at half / full precision for now.
+        self.gate = ReplicatedLinear(
+            self.hidden_size,
+            self.num_total_experts,
+            bias=False,
+            params_dtype=self.params_dtype,
+            quant_config=None)
+
+        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
+            params_dtype = torch.float8_e4m3fn
+
+        self.w13_weight = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        2 * self.intermediate_size,
+                        self.hidden_size,
+                        dtype=params_dtype))
+        self.w2_weight = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        self.hidden_size,
+                        self.intermediate_size,
+                        dtype=params_dtype))
+
+        set_weight_attrs(self.w13_weight, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2_weight, {
+            "weight_loader": self.weight_loader,
+        })
+
+        # Used for fp8.
+        self.w13_scale = None
+        self.w2_scale = None
+        self.a13_scale = None
+        self.a2_scale = None
+
+        if self.use_fp8:
+            # WEIGHT_SCALE (for fp8)
+            self.w13_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False)
+            self.w2_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False)
+
+            # If loading fp8 checkpoint, pass the weight loaders.
+            # If loading an fp16 checkpoint, do not (we will quantize in
+            #   process_weights_after_loading()
+            if quant_config.is_checkpoint_fp8_serialized:
+                set_weight_attrs(self.w13_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+                set_weight_attrs(self.w2_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+
+            # ACT_SCALE (for fp8)
+            if quant_config.activation_scheme == "static":
+                if not quant_config.is_checkpoint_fp8_serialized:
+                    raise ValueError(
+                        "Found static activation scheme for checkpoint that "
+                        "was not serialized fp8."
+                    )
+                self.a13_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False)
+                self.a2_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False)
+
+                set_weight_attrs(self.a13_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+                set_weight_attrs(self.a2_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+
+    def weight_loader(
+        self,
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        expert_id: int,
+    ):
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        if weight_name.endswith("w1.weight"):
+            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w3.weight"):
+            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w2.weight"):
+            param_data[expert_id, :, :] = loaded_weight[:, shard]
+        if "act_scale" in weight_name or "weight_scale" in weight_name:
+            param_data[expert_id] = loaded_weight
+
+    def process_weights_after_loading(self):
+        # Fp8 is the only case where we need to process after loading.
+        if not self.use_fp8:
+            return
+
+        # If checkpoint is fp16, quantize here.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(self.w13_weight.data, dtype=torch.float8_e4m3fn)
+            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
+            for expert in range(self.num_total_experts):
+                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
+                    self.w13_weight.data[expert, :, :])
+                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
+                    self.w2_weight.data[expert, :, :])
+            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
+            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
+
+        # If checkpoint is fp8 + static, cleanup act_scales.
+        #   Since state_dict has an act_scale per expert but our kernels
+        #   are passed one act_scale shared across all experts.
+        elif self.quant_config.activation_scheme == "static":
+            if self.a13_scale is None or self.a2_scale is None:
+                raise ValueError(
+                    "QuantConfig has static quantization, but found "
+                    "activation scales are None."
+                )
+            if (not all_close_1d(self.a13_scale)
+                    or not all_close_1d(self.a2_scale)):
+                print_warning_once(
+                    "Found act_scales that are not equal for fp8 MoE layer. "
+                    "Using the maximum across experts for each layer. "
+                )
+            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
+            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            router_logits,
+            self.top_k,
+            renormalize=True,
+            inplace=True,
+            use_fp8=self.use_fp8,
+            w1_scale=self.w13_scale,
+            w2_scale=self.w2_scale,
+            a1_scale=self.a13_scale,
+            a2_scale=self.a2_scale,
+        )
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states.view(num_tokens, hidden_size)

 class MixtralAttention(nn.Module):

@@ -234,7 +329,12 @@ class MixtralDecoderLayer(nn.Module):
             sliding_window=config.sliding_window,
             quant_config=quant_config,
         )
-        self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
+        self.block_sparse_moe = MixtralMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps

@@ -342,11 +442,35 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]
+        expert_params_mapping = [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id)
+            ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ] + [
+            # These are the weights for the experts
+            # (param_name, weight_name, expert_id)
+            ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ] + [
+            # These are the activation scales for the experts
+            # (param_name, weight_name, expert_id)
+            ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
+             f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)

@@ -358,15 +482,30 @@ class MixtralForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if "block_sparse_moe.experts." in name and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for param_name, weight_name, expert_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        weight_name,
+                        expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                    weight_loader(param, loaded_weight)
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))

 EntryClass = MixtralForCausalLM
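For intuition about what process_weights_after_loading does when the checkpoint is fp16: each expert weight gets one float32 scale chosen so the tensor fits the fp8 (e4m3) range. The snippet below is an illustrative stand-in for vllm's ops.scaled_fp8_quant, not the kernel the diff actually calls; the function name and rounding details are ours.

import torch

def naive_scaled_fp8_quant(w: torch.Tensor):
    # One scale per tensor: map the largest magnitude to the largest
    # representable e4m3 value, then round-trip through fp8.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = (w.abs().amax().float() / finfo.max).clamp(min=1e-12)
    q = (w.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale  # dequantize as q.to(w.dtype) * scale

w = torch.randn(8, 16, dtype=torch.float16)
q, s = naive_scaled_fp8_quant(w)
w_hat = (q.to(torch.float32) * s).to(torch.float16)
print((w - w_hat).abs().max())  # small quantization error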
python/sglang/srt/models/mixtral_quant.py (new file, 0 → 100644)

# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1
"""Inference-only Mixtral model."""
from typing import Iterable, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata


class MixtralMLP(nn.Module):
    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.num_experts = num_experts
        self.ffn_dim = intermediate_size
        self.hidden_dim = hidden_size

        self.w1 = ReplicatedLinear(
            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
        )
        self.w2 = ReplicatedLinear(
            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
        )
        self.w3 = ReplicatedLinear(
            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
        )

        # TODO: Use vllm's SiluAndMul
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        w1_out, _ = self.w1(hidden_states)
        w1_out = self.act_fn(w1_out)
        w3_out, _ = self.w3(hidden_states)
        current_hidden_states = w1_out * w3_out
        current_hidden_states, _ = self.w2(current_hidden_states)
        return current_hidden_states


class MixtralMoE(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.num_total_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.num_total_experts}."
            )
        # Split experts equally between ranks
        self.expert_indicies = np.array_split(
            range(self.num_total_experts), self.tp_size
        )[self.rank].tolist()
        if not self.expert_indicies:
            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")

        self.experts = nn.ModuleList(
            [
                (
                    MixtralMLP(
                        self.num_total_experts,
                        config.hidden_size,
                        config.intermediate_size,
                        quant_config=quant_config,
                    )
                    if idx in self.expert_indicies
                    else None
                )
                for idx in range(self.num_total_experts)
            ]
        )
        self.gate = ReplicatedLinear(
            config.hidden_size, self.num_total_experts, bias=False, quant_config=None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits, _ = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        final_hidden_states = None
        for expert_idx in self.expert_indicies:
            expert_layer = self.experts[expert_idx]
            expert_mask = selected_experts == expert_idx
            expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)

            current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
            if final_hidden_states is None:
                final_hidden_states = current_hidden_states
            else:
                final_hidden_states.add_(current_hidden_states)

        return tensor_model_parallel_all_reduce(final_hidden_states)


class MixtralAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        quant_config: Optional[QuantizationConfig] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            quant_config=quant_config,
        )
        self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual


class MixtralModel(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                MixtralDecoderLayer(config, i, quant_config=quant_config)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions, hidden_states, input_metadata, residual
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class QuantMixtralForCausalLM(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = MixtralModel(config, quant_config=quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if "block_sparse_moe.experts." in name and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = QuantMixtralForCausalLM
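The forward pass of MixtralMoE above picks top_k experts per token and renormalizes their softmax weights before the weighted sum. A tiny worked example of just that routing math (made-up logits, 1 token, 4 experts, top_k=2):

import torch
import torch.nn.functional as F

router_logits = torch.tensor([[1.0, 3.0, 0.5, 2.0]])  # 1 token, 4 experts
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, 2, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
print(selected_experts)  # tensor([[1, 3]]): experts 1 and 3 handle this token
print(routing_weights)   # ~tensor([[0.73, 0.27]]): renormalized to sum to 1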
python/sglang/srt/server_args.py

@@ -15,6 +15,7 @@ class ServerArgs:
     chat_template: Optional[str] = None
     trust_remote_code: bool = True
     context_length: Optional[int] = None
+    quantization: Optional[str] = None

     # Port
     host: str = "127.0.0.1"

@@ -135,6 +136,12 @@ class ServerArgs:
             default=ServerArgs.context_length,
             help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
         )
+        parser.add_argument(
+            "--quantization",
+            type=str,
+            default=ServerArgs.quantization,
+            help="The quantization method.",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
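The new field is forwarded to vllm's model config in model_runner.py above, so quantization can now be selected per server, e.g. via `--quantization fp8` on the launch command line (`python -m sglang.launch_server ...`, the entry point assumed from the project README). A minimal Python sketch under the same assumption; the checkpoint id is only an example and other ServerArgs fields keep their defaults:

from sglang.srt.server_args import ServerArgs

# Hypothetical configuration; "fp8" is what routes model loading to vllm's
# Fp8Config and hence to the fused fp8 MoE path added in mixtral.py.
args = ServerArgs(
    model_path="mistralai/Mixtral-8x7B-Instruct-v0.1",
    quantization="fp8",  # new field added in this commit
)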
python/sglang/srt/utils.py

@@ -106,6 +106,7 @@ def get_available_gpu_memory(gpu_id, distributed=True):
             "which may cause useless memory allocation for torch CUDA context.",
         )

+    torch.cuda.empty_cache()
     free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

     if distributed:
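The single added line releases PyTorch's cached allocator blocks before querying free memory, so mem_get_info reflects memory that is genuinely available rather than memory the caching allocator is still holding. A small self-contained sketch of the same measurement (the helper name is ours, not from the repo):

import torch

def free_gpu_memory_gb(gpu_id: int = 0) -> float:
    # Release cached (but unused) blocks first, mirroring the line added to
    # get_available_gpu_memory; otherwise cached memory is reported as used.
    torch.cuda.empty_cache()
    free_bytes, _total_bytes = torch.cuda.mem_get_info(gpu_id)
    return free_bytes / (1 << 30)

if torch.cuda.is_available():
    print(f"{free_gpu_memory_gb(0):.2f} GB free on GPU 0")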