Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
2d96da81
Unverified
Commit
2d96da81
authored
Jul 19, 2024
by
Ying Sheng
Committed by
GitHub
Jul 19, 2024
Browse files
refactor model loader [unreachable code]: initial refactor (#655)
parent
c126a6cc
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
2116 additions
and
0 deletions
+2116
-0
python/sglang/srt/layers/linear.py
python/sglang/srt/layers/linear.py
+869
-0
python/sglang/srt/layers/quantization/__init__.py
python/sglang/srt/layers/quantization/__init__.py
+49
-0
python/sglang/srt/layers/quantization/fp8.py
python/sglang/srt/layers/quantization/fp8.py
+662
-0
python/sglang/srt/model_loader/model_loader.py
python/sglang/srt/model_loader/model_loader.py
+276
-0
python/sglang/srt/model_loader/utils.py
python/sglang/srt/model_loader/utils.py
+260
-0
No files found.
python/sglang/srt/layers/linear.py
0 → 100644
View file @
2d96da81
# temporarily adapted from https://github.com/vllm-project/vllm/blob/e76466dde2bc9525d55165ceaa600d298c7bf773/vllm/model_executor/layers/linear.py
# FIXME: refactor the linear abstraction
from
abc
import
abstractmethod
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch.nn.functional
as
F
from
torch.nn.parameter
import
Parameter
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
,
)
from
vllm.model_executor.utils
import
set_weight_attrs
logger
=
init_logger
(
__name__
)
def
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
):
marlin_tile_size
=
getattr
(
param
,
"marlin_tile_size"
,
None
)
if
marlin_tile_size
is
None
:
return
shard_size
,
shard_offset
return
shard_size
*
marlin_tile_size
,
shard_offset
*
marlin_tile_size
def
adjust_bitsandbytes_shard
(
param
:
Parameter
,
qkv_offsets
:
Dict
[
str
,
Tuple
[
int
,
int
]],
loaded_shard_id
:
str
)
->
Tuple
[
int
,
int
]:
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
total
,
_
=
qkv_offsets
[
"total"
]
orig_offset
,
orig_size
=
qkv_offsets
[
loaded_shard_id
]
quantized_total
=
param
.
data
.
shape
[
0
]
quantized_offset
=
orig_offset
*
quantized_total
//
total
quantized_size
=
orig_size
*
quantized_total
//
total
return
quantized_size
,
quantized_offset
def
adjust_scalar_to_fused_array
(
param
,
loaded_weight
,
shard_id
):
"""For fused modules (QKV and MLP) we have an array of length
N that holds 1 scale for each "logical" matrix. So the param
is an array of length N. The loaded_weight corresponds to
one of the shards on disk. Here, we slice the param based on
the shard_id for loading.
"""
qkv_idxs
=
{
"q"
:
0
,
"k"
:
1
,
"v"
:
2
}
if
isinstance
(
shard_id
,
str
):
shard_id
=
qkv_idxs
[
shard_id
]
elif
not
isinstance
(
shard_id
,
int
):
raise
ValueError
(
f
"Unknown Shard Id
{
shard_id
}
"
)
# AutoFP8 scales do not have a shape
# compressed-tensors scales do have a shape
if
len
(
loaded_weight
.
shape
)
!=
0
:
assert
loaded_weight
.
shape
[
0
]
==
1
loaded_weight
=
loaded_weight
[
0
]
return
param
[
shard_id
],
loaded_weight
class
LinearMethodBase
(
QuantizeMethodBase
):
"""Base class for different (maybe quantized) linear methods."""
@
abstractmethod
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
"""Create weights for a linear layer.
The weights will be set as attributes of the layer.
Args:
layer: The layer that is using the LinearMethodBase factory.
input_size_per_partition: Size of the weight input dim on rank X.
output_partition_sizes: Sizes of the output dim of each logical
weight on rank X. E.g., output_partition_sizes for QKVLinear
is a list contains the width of Wq, Wk, Wv on rank X.
input_size: Size of the input dim of the weight across all ranks.
output_size: Size of the output dim of the weight across all ranks.
params_dtype: Datatype of the parameters.
"""
raise
NotImplementedError
@
abstractmethod
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise
NotImplementedError
class
UnquantizedLinearMethod
(
LinearMethodBase
):
"""Linear method without quantization.
Args:
separate_bias_add: If true, add bias separately after matrix
multiplication.
"""
def
__init__
(
self
,
separate_bias_add
:
bool
=
False
):
self
.
separate_bias_add
=
separate_bias_add
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
})
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
extra_weight_attrs
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
weight
=
layer
.
weight
if
self
.
separate_bias_add
:
if
bias
is
not
None
:
return
F
.
linear
(
x
,
weight
)
+
bias
return
F
.
linear
(
x
,
weight
)
return
F
.
linear
(
x
,
weight
,
bias
)
class
LinearBase
(
torch
.
nn
.
Module
):
"""Base linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
# Keep input parameters
self
.
input_size
=
input_size
self
.
output_size
=
output_size
self
.
skip_bias_add
=
skip_bias_add
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
if
quant_config
is
None
:
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
UnquantizedLinearMethod
()
else
:
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
raise
NotImplementedError
class
ReplicatedLinear
(
LinearBase
):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
# All the linear layer supports quant method.
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
self
,
self
.
input_size
,
[
self
.
output_size
],
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
,
)
if
bias
:
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size
,
dtype
=
self
.
params_dtype
)
)
set_weight_attrs
(
self
.
bias
,
{
"output_dim"
:
0
})
else
:
self
.
register_parameter
(
"bias"
,
None
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
output
=
self
.
quant_method
.
apply
(
self
,
x
,
bias
)
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
return
output
,
output_bias
def
extra_repr
(
self
)
->
str
:
s
=
f
"in_features=
{
self
.
input_size
}
"
s
+=
f
", output_features=
{
self
.
output_size
}
"
s
+=
f
", bias=
{
self
.
bias
is
not
None
}
"
return
s
class
ColumnParallelLinear
(
LinearBase
):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Args:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias.
gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
"""
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
bias
:
bool
=
True
,
gather_output
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
output_sizes
:
Optional
[
List
[
int
]]
=
None
,
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
self
.
gather_output
=
gather_output
# Divide the weight matrix along the last dimension.
tp_size
=
get_tensor_model_parallel_world_size
()
assert
self
.
quant_method
is
not
None
self
.
output_size_per_partition
=
divide
(
self
.
output_size
,
tp_size
)
self
.
output_partition_sizes
=
[
self
.
output_size_per_partition
]
# If QKV or MergedColumn, use output size of each partition.
if
hasattr
(
self
,
"output_sizes"
):
self
.
output_partition_sizes
=
[
divide
(
output_size
,
tp_size
)
for
output_size
in
self
.
output_sizes
]
if
output_sizes
is
None
:
output_sizes
=
[
output_size
]
self
.
quant_method
.
create_weights
(
layer
=
self
,
input_size_per_partition
=
self
.
input_size
,
output_partition_sizes
=
self
.
output_partition_sizes
,
input_size
=
self
.
input_size
,
output_size
=
self
.
output_size
,
params_dtype
=
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
,
)
if
bias
:
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size_per_partition
,
dtype
=
params_dtype
)
)
set_weight_attrs
(
self
.
bias
,
{
"output_dim"
:
0
,
"weight_loader"
:
self
.
weight_loader
,
},
)
else
:
self
.
register_parameter
(
"bias"
,
None
)
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
if
param
.
data
.
dtype
!=
loaded_weight
.
dtype
:
param
.
data
=
torch
.
empty_like
(
param
.
data
,
dtype
=
loaded_weight
.
dtype
,
device
=
"cuda"
)
tp_rank
=
get_tensor_model_parallel_rank
()
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
param_data
=
param
.
data
if
output_dim
is
not
None
:
shard_size
=
param_data
.
shape
[
output_dim
]
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
input_
):
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
# Matrix multiply.
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
if
self
.
gather_output
:
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
else
:
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
return
output
,
output_bias
def
extra_repr
(
self
)
->
str
:
s
=
f
"in_features=
{
self
.
input_size
}
"
s
+=
f
", output_features=
{
self
.
output_size_per_partition
}
"
s
+=
f
", bias=
{
self
.
bias
is
not
None
}
"
s
+=
f
", tp_size=
{
get_tensor_model_parallel_world_size
()
}
"
s
+=
f
", gather_output=
{
self
.
gather_output
}
"
return
s
class
MergedColumnParallelLinear
(
ColumnParallelLinear
):
"""Packed linear layers with column parallelism.
Similar to ColumnParallelLinear, but the weight matrix is concatenated
along the output dimension. When the weight matrix is loaded, the
different partitions are sharded separately.
Args:
input_size: input dimension of the linear layer.
output_sizes: list of output dimensions of the linear layer.
bias: If true, add bias.
gather_output: If true, call all-gather on output and make the output
available to all GPUs, otherwise, every GPU will have
its own output.
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def
__init__
(
self
,
input_size
:
int
,
output_sizes
:
List
[
int
],
bias
:
bool
=
True
,
gather_output
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
super
().
__init__
(
input_size
=
input_size
,
output_size
=
sum
(
output_sizes
),
bias
=
bias
,
gather_output
=
gather_output
,
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
,
)
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
Optional
[
int
]
=
None
,
):
if
param
.
data
.
dtype
!=
loaded_weight
.
dtype
:
param
.
data
=
torch
.
empty_like
(
param
.
data
,
dtype
=
loaded_weight
.
dtype
,
device
=
"cuda"
)
param_data
=
param
.
data
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
# Special case for AQLM codebooks.
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
# Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array
=
getattr
(
param
,
"needs_scalar_to_array"
,
False
)
if
loaded_shard_id
is
None
:
# Loaded weight is already fused on disk (qkv/mlp).
if
output_dim
is
None
:
if
needs_scalar_to_array
is
not
None
:
param_data
,
loaded_weight
=
adjust_scalar_to_fused_array
(
param_data
,
loaded_weight
,
0
)
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
return
current_shard_offset
=
0
shard_offsets
:
List
[
Tuple
[
int
,
int
,
int
]]
=
[]
for
i
,
output_size
in
enumerate
(
self
.
output_sizes
):
shard_offsets
.
append
((
i
,
current_shard_offset
,
output_size
))
current_shard_offset
+=
output_size
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
for
shard_id
,
shard_offset
,
shard_size
in
shard_offsets
:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
pack_factor
shard_offset
=
shard_offset
//
param
.
pack_factor
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
loaded_weight_shard
=
loaded_weight
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
self
.
weight_loader
(
param
,
loaded_weight_shard
,
shard_id
)
return
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
tp_rank
=
get_tensor_model_parallel_rank
()
tp_size
=
get_tensor_model_parallel_world_size
()
if
output_dim
is
not
None
:
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
tp_size
# Special case for quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
pack_factor
shard_offset
=
shard_offset
//
param
.
pack_factor
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
use_bitsandbytes
=
getattr
(
param
,
"use_bitsandbytes"
,
False
)
if
use_bitsandbytes
:
shard_size
=
loaded_weight
.
shape
[
output_dim
]
shard_offset
=
loaded_weight
.
shape
[
output_dim
]
*
loaded_shard_id
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
# Special case for AQLM codebooks.
elif
is_metadata
:
# metadata indicates fixed size concatenated along dim 0
shard_size
=
loaded_weight
.
shape
[
0
]
shard_offset
=
loaded_shard_id
*
shard_size
param_data
=
param_data
.
narrow
(
0
,
shard_offset
,
shard_size
)
# Special case for per-tensor scales in fused case.
elif
needs_scalar_to_array
:
param_data
,
loaded_weight
=
adjust_scalar_to_fused_array
(
param_data
,
loaded_weight
,
loaded_shard_id
)
else
:
ignore_warning
=
getattr
(
param
,
"ignore_warning"
,
False
)
if
not
ignore_warning
:
logger
.
warning
(
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions."
)
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
class
QKVParallelLinear
(
ColumnParallelLinear
):
"""Linear layers for the attention's QKV transformation.
Linear layers for the linear transformation of the query, key, and value
vectors in the attention layer. The weight matrix is concatenated along
the output dimension. The layer is parallelized along the head dimension.
When the number of key/value heads is smaller than the number of query
heads (e.g., multi-query/grouped-query attention), the key/value head may
be replicated while the query heads are partitioned.
Args:
hidden_size: input hidden state size of the transformer.
head_size: size of each attention head.
total_num_heads: total number of attention query heads.
total_num_kv_heads: total number of attention key/value heads. If
None, assume total_num_kv_heads = total_num_heads.
bias: If true, add bias.
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def
__init__
(
self
,
hidden_size
:
int
,
head_size
:
int
,
total_num_heads
:
int
,
total_num_kv_heads
:
Optional
[
int
]
=
None
,
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
self
.
hidden_size
=
hidden_size
self
.
head_size
=
head_size
self
.
total_num_heads
=
total_num_heads
if
total_num_kv_heads
is
None
:
total_num_kv_heads
=
total_num_heads
self
.
total_num_kv_heads
=
total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads
=
divide
(
self
.
total_num_heads
,
tp_size
)
if
tp_size
>=
self
.
total_num_kv_heads
:
self
.
num_kv_heads
=
1
self
.
num_kv_head_replicas
=
divide
(
tp_size
,
self
.
total_num_kv_heads
)
else
:
self
.
num_kv_heads
=
divide
(
self
.
total_num_kv_heads
,
tp_size
)
self
.
num_kv_head_replicas
=
1
input_size
=
self
.
hidden_size
output_size
=
(
(
self
.
num_heads
+
2
*
self
.
num_kv_heads
)
*
tp_size
*
self
.
head_size
)
self
.
output_sizes
=
[
self
.
num_heads
*
self
.
head_size
*
tp_size
,
# q_proj
self
.
num_kv_heads
*
self
.
head_size
*
tp_size
,
# k_proj
self
.
num_kv_heads
*
self
.
head_size
*
tp_size
,
# v_proj
]
super
().
__init__
(
input_size
=
input_size
,
output_size
=
output_size
,
bias
=
bias
,
gather_output
=
False
,
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
,
)
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
Optional
[
str
]
=
None
,
):
if
param
.
data
.
dtype
!=
loaded_weight
.
dtype
:
param
.
data
=
torch
.
empty_like
(
param
.
data
,
dtype
=
loaded_weight
.
dtype
,
device
=
"cuda"
)
param_data
=
param
.
data
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
# Special case for AQLM codebooks.
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
# Special case for per-tensor scales in fused case.
needs_scalar_to_array
=
getattr
(
param
,
"needs_scalar_to_array"
,
False
)
if
loaded_shard_id
is
None
:
# Loaded weight is already fused on disk (qkv/mlp).
if
output_dim
is
None
:
if
needs_scalar_to_array
is
not
None
:
param_data
,
loaded_weight
=
adjust_scalar_to_fused_array
(
param_data
,
loaded_weight
,
0
)
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
return
shard_offsets
=
[
# (shard_id, shard_offset, shard_size)
(
"q"
,
0
,
self
.
total_num_heads
*
self
.
head_size
),
(
"k"
,
self
.
total_num_heads
*
self
.
head_size
,
self
.
total_num_kv_heads
*
self
.
head_size
,
),
(
"v"
,
(
self
.
total_num_heads
+
self
.
total_num_kv_heads
)
*
self
.
head_size
,
self
.
total_num_kv_heads
*
self
.
head_size
,
),
]
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
for
shard_id
,
shard_offset
,
shard_size
in
shard_offsets
:
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
pack_factor
shard_offset
=
shard_offset
//
param
.
pack_factor
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
loaded_weight_shard
=
loaded_weight
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
self
.
weight_loader
(
param
,
loaded_weight_shard
,
shard_id
)
return
tp_rank
=
get_tensor_model_parallel_rank
()
assert
loaded_shard_id
in
[
"q"
,
"k"
,
"v"
]
# If output dim is defined, use the default loading process.
if
output_dim
is
not
None
:
if
loaded_shard_id
==
"q"
:
shard_offset
=
0
shard_size
=
self
.
num_heads
*
self
.
head_size
elif
loaded_shard_id
==
"k"
:
shard_offset
=
self
.
num_heads
*
self
.
head_size
shard_size
=
self
.
num_kv_heads
*
self
.
head_size
elif
loaded_shard_id
==
"v"
:
shard_offset
=
(
self
.
num_heads
+
self
.
num_kv_heads
)
*
self
.
head_size
shard_size
=
self
.
num_kv_heads
*
self
.
head_size
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
pack_factor
shard_offset
=
shard_offset
//
param
.
pack_factor
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
use_bitsandbytes
=
getattr
(
param
,
"use_bitsandbytes"
,
False
)
if
use_bitsandbytes
:
orig_qkv_offsets
=
{
"q"
:
(
0
,
self
.
num_heads
*
self
.
head_size
),
"k"
:
(
self
.
num_heads
*
self
.
head_size
,
self
.
num_kv_heads
*
self
.
head_size
,
),
"v"
:
(
(
self
.
num_heads
+
self
.
num_kv_heads
)
*
self
.
head_size
,
self
.
num_kv_heads
*
self
.
head_size
,
),
"total"
:
(
(
self
.
num_heads
+
2
*
self
.
num_kv_heads
)
*
self
.
head_size
,
0
,
),
}
shard_size
,
shard_offset
=
adjust_bitsandbytes_shard
(
param
,
orig_qkv_offsets
,
loaded_shard_id
)
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
if
loaded_shard_id
==
"q"
:
shard_id
=
tp_rank
else
:
shard_id
=
tp_rank
//
self
.
num_kv_head_replicas
start_idx
=
shard_id
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
# Special case for for AQLM codebooks.
elif
is_metadata
:
# metadata indicates fixed size concatenated along dim 0
shard_size
=
loaded_weight
.
shape
[
0
]
shard_index
=
[
"q"
,
"k"
,
"v"
].
index
(
loaded_shard_id
)
param_data
=
param_data
.
narrow
(
0
,
shard_index
*
shard_size
,
shard_size
)
# Special case for per-tensor scales in fused case.
elif
needs_scalar_to_array
:
param_data
,
loaded_weight
=
adjust_scalar_to_fused_array
(
param_data
,
loaded_weight
,
loaded_shard_id
)
else
:
ignore_warning
=
getattr
(
param
,
"ignore_warning"
,
False
)
if
not
ignore_warning
:
logger
.
warning
(
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions."
)
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
class
RowParallelLinear
(
LinearBase
):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
- -
| A_1 |
| . |
A = | . | X = [X_1, ..., X_p]
| . |
| A_p |
- -
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
skip_bias_add: This was added to enable performance optimization where
bias can be fused with other element-wise operations.
We skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
bias
:
bool
=
True
,
input_is_parallel
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
reduce_results
:
bool
=
True
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
self
.
input_is_parallel
=
input_is_parallel
self
.
reduce_results
=
reduce_results
# Divide the weight matrix along the last dimension.
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
layer
=
self
,
input_size_per_partition
=
self
.
input_size_per_partition
,
output_partition_sizes
=
[
self
.
output_size
],
input_size
=
self
.
input_size
,
output_size
=
self
.
output_size
,
params_dtype
=
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
,
)
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
raise
ValueError
(
"When not reduce the results, adding bias to the "
"results can lead to incorrect results"
)
if
bias
:
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size
,
dtype
=
params_dtype
))
set_weight_attrs
(
self
.
bias
,
{
"output_dim"
:
0
,
"weight_loader"
:
self
.
weight_loader
,
},
)
else
:
self
.
register_parameter
(
"bias"
,
None
)
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
if
param
.
data
.
dtype
!=
loaded_weight
.
dtype
:
param
.
data
=
torch
.
empty_like
(
param
.
data
,
dtype
=
loaded_weight
.
dtype
,
device
=
"cuda"
)
param_data
=
param
.
data
tp_rank
=
get_tensor_model_parallel_rank
()
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
if
input_dim
is
not
None
:
shard_size
=
param
.
data
.
shape
[
input_dim
]
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
input_dim
,
start_idx
,
shard_size
)
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
input_
):
# Set up backprop all-reduce.
if
self
.
input_is_parallel
:
input_parallel
=
input_
else
:
tp_rank
=
get_tensor_model_parallel_rank
()
splitted_input
=
split_tensor_along_last_dim
(
input_
,
num_partitions
=
self
.
tp_size
)
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
# Matrix multiply.
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
output_
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
output_
=
output_parallel
if
not
self
.
skip_bias_add
:
output
=
output_
+
self
.
bias
if
self
.
bias
is
not
None
else
output_
output_bias
=
None
else
:
output
=
output_
output_bias
=
self
.
bias
return
output
,
output_bias
def
extra_repr
(
self
)
->
str
:
s
=
f
"input_features=
{
self
.
input_size_per_partition
}
"
s
+=
f
", output_features=
{
self
.
output_size
}
"
s
+=
f
", bias=
{
self
.
bias
is
not
None
}
"
s
+=
f
", tp_size=
{
self
.
tp_size
}
"
s
+=
f
", reduce_results=
{
self
.
reduce_results
}
"
return
s
python/sglang/srt/layers/quantization/__init__.py
0 → 100644
View file @
2d96da81
# temporarily adapted from vLLM
# FIXME: in progress of refactoring the model loader
from
typing
import
Dict
,
Type
from
vllm.model_executor.layers.quantization.aqlm
import
AQLMConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.bitsandbytes
import
BitsAndBytesConfig
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensorsConfig
,
)
from
vllm.model_executor.layers.quantization.deepspeedfp
import
DeepSpeedFPConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
GPTQMarlinConfig
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
GPTQMarlin24Config
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
from
sglang.srt.layers.quantization.fp8
import
Fp8Config
QUANTIZATION_METHODS
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
"aqlm"
:
AQLMConfig
,
"awq"
:
AWQConfig
,
"deepspeedfp"
:
DeepSpeedFPConfig
,
"fp8"
:
Fp8Config
,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin"
:
MarlinConfig
,
"gptq_marlin_24"
:
GPTQMarlin24Config
,
"gptq_marlin"
:
GPTQMarlinConfig
,
"gptq"
:
GPTQConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
"bitsandbytes"
:
BitsAndBytesConfig
,
}
def
get_quantization_config
(
quantization
:
str
)
->
Type
[
QuantizationConfig
]:
if
quantization
not
in
QUANTIZATION_METHODS
:
raise
ValueError
(
f
"Invalid quantization method:
{
quantization
}
"
)
return
QUANTIZATION_METHODS
[
quantization
]
__all__
=
[
"QuantizationConfig"
,
"get_quantization_config"
,
"QUANTIZATION_METHODS"
,
]
python/sglang/srt/layers/quantization/fp8.py
0 → 100644
View file @
2d96da81
# adapted from https://github.com/vllm-project/vllm/blob/e76466dde2bc9525d55165ceaa600d298c7bf773/vllm/model_executor/layers/quantization/fp8.py
# FIXME refactor in progress
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
import
torch
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
,
fused_moe
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
,
)
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQMarlinState
,
marlin_permute_scales
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
pack_fp8_to_int32
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
from
sglang.srt.layers.linear
import
LinearBase
,
LinearMethodBase
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
logger
=
init_logger
(
__name__
)
def
cutlass_fp8_supported
()
->
bool
:
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
return
ops
.
cutlass_scaled_mm_supports_fp8
(
capability
)
class
Fp8Config
(
QuantizationConfig
):
"""Config class for FP8."""
def
__init__
(
self
,
is_checkpoint_fp8_serialized
:
bool
=
False
,
activation_scheme
:
str
=
"dynamic"
,
)
->
None
:
self
.
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
if
is_checkpoint_fp8_serialized
:
logger
.
warning
(
"Detected fp8 checkpoint. Please note that the "
"format is experimental and subject to change."
)
if
activation_scheme
not
in
ACTIVATION_SCHEMES
:
raise
ValueError
(
f
"Unsupported activation scheme
{
activation_scheme
}
"
)
self
.
activation_scheme
=
activation_scheme
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"fp8"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
bfloat16
,
torch
.
half
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
80
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"Fp8Config"
:
quant_method
=
cls
.
get_from_keys
(
config
,
[
"quant_method"
])
is_checkpoint_fp8_serialized
=
"fp8"
in
quant_method
activation_scheme
=
cls
.
get_from_keys
(
config
,
[
"activation_scheme"
])
return
cls
(
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
,
activation_scheme
=
activation_scheme
,
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
):
return
Fp8LinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
Fp8MoEMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
Fp8LinearMethod
(
LinearMethodBase
):
"""Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support.
2. Only support float8_e4m3fn data type due to the limitation of
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
Args:
quant_config: The quantization config.
"""
def
__init__
(
self
,
quant_config
:
Fp8Config
):
self
.
quant_config
=
quant_config
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
self
.
use_marlin
=
capability
<
89
def
_create_scale_param
(
self
,
scale_name
:
str
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
List
[
int
],
**
extra_weight_attrs
,
)
->
None
:
scale
=
Parameter
(
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
scale
[:]
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
min
layer
.
register_parameter
(
scale_name
,
scale
)
set_weight_attrs
(
scale
,
{
**
extra_weight_attrs
,
"needs_scalar_to_array"
:
True
,
},
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
del
input_size
,
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
layer
.
process_after_load
=
True
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
orig_dtype
=
params_dtype
# WEIGHT
# weight_dtype = (torch.float8_e4m3fn
# if self.quant_config.is_checkpoint_fp8_serialized else
# params_dtype)
weight_dtype
=
torch
.
float8_e4m3fn
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
weight_dtype
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
**
extra_weight_attrs
,
"input_dim"
:
1
,
"output_dim"
:
0
,
},
)
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# WEIGHT SCALE
self
.
_create_scale_param
(
scale_name
=
"weight_scale"
,
layer
=
layer
,
output_partition_sizes
=
output_partition_sizes
,
**
extra_weight_attrs
,
)
# INPUT ACTIVATION SCALE
if
self
.
quant_config
.
activation_scheme
==
"static"
:
self
.
_create_scale_param
(
scale_name
=
"input_scale"
,
layer
=
layer
,
output_partition_sizes
=
output_partition_sizes
,
**
extra_weight_attrs
,
)
# For GPUs without FP8 hardware support, we use Marlin for fast
# fused dequantization
if
self
.
use_marlin
:
layer
.
marlin_state
=
GPTQMarlinState
.
REPACK
def
prepare_layer_for_marlin
(
self
,
layer
:
Module
)
->
None
:
print_warning_once
(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
part_size_n
=
layer
.
output_size_per_partition
part_size_k
=
layer
.
input_size_per_partition
assert
layer
.
marlin_state
==
GPTQMarlinState
.
REPACK
layer
.
marlin_state
=
GPTQMarlinState
.
READY
device
=
layer
.
weight
.
device
# WEIGHTS
# Repack weights to gptq format (packed int32 elements)
packed_gptq_qweight
=
pack_fp8_to_int32
(
layer
.
weight
)
# Repack weights to marlin format
marlin_qweight
=
ops
.
gptq_marlin_repack
(
b_q_weight
=
packed_gptq_qweight
,
perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
device
),
size_k
=
part_size_k
,
size_n
=
part_size_n
,
num_bits
=
8
,
)
layer
.
weight
=
Parameter
(
marlin_qweight
,
requires_grad
=
False
)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales
=
(
layer
.
weight_scale
.
repeat
(
1
,
part_size_n
).
to
(
layer
.
orig_dtype
).
to
(
device
)
)
# Permute scales
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
size_k
=
part_size_k
,
size_n
=
part_size_n
,
group_size
=-
1
,
num_bits
=
8
,
)
layer
.
weight_scale
=
Parameter
(
marlin_scales
,
requires_grad
=
False
)
# Allocate marlin workspace
max_workspace_size
=
(
part_size_n
//
GPTQ_MARLIN_MIN_THREAD_N
)
*
GPTQ_MARLIN_MAX_PARALLEL
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
layer
.
workspace
=
workspace
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
not
hasattr
(
layer
,
"process_after_load"
)
or
not
layer
.
process_after_load
:
return
# If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
logical_widths
=
None
layer
.
input_scale
=
None
if
self
.
use_marlin
:
self
.
prepare_layer_for_marlin
(
layer
)
return
# If checkpoint is fp8, requantize the separately quantized logical
# weights into a single fp8 weight with a single weight scale.
else
:
# WEIGHT_SCALE / WEIGHT
# Loop over logical weights, requantizing with single scale.
max_w_scale
=
layer
.
weight_scale
.
max
()
# QKV / MLP is fused in the on disk checkpoint if any of the
# weight scales are still set to the default since we initialize
# N weight scales for N shards but we only load 1 weight scale
# from disk in this case. As a result, we skip dequant -> requant
# since we already have quantized QKV together.
# Sample Model with fused checkpoint:
# * nm-testing/Phi-3-mini-128k-instruct-FP8
unfused_module_in_checkpoint
=
(
layer
.
weight_scale
[
-
1
]
>
torch
.
finfo
(
torch
.
float8_e4m3fn
).
min
)
if
unfused_module_in_checkpoint
:
start
=
0
for
idx
,
logical_width
in
enumerate
(
layer
.
logical_widths
):
end
=
start
+
logical_width
weight_dq
=
per_tensor_dequantize
(
layer
.
weight
[
start
:
end
,
:],
layer
.
weight_scale
[
idx
]
)
layer
.
weight
[
start
:
end
,
:]
=
per_tensor_quantize
(
weight_dq
,
layer
.
weight_scale
.
max
()
)
start
=
end
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
# WEIGHT
# Transpose weight for passing to torch._scaled_mm
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
# INPUT ACTIVATION SCALE
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
# Static: set to max of the input_scales (since they are equal).
if
self
.
quant_config
.
activation_scheme
==
"dynamic"
:
layer
.
input_scale
=
None
elif
self
.
quant_config
.
activation_scheme
==
"static"
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
else
:
raise
ValueError
(
f
"Unknown scheme
{
self
.
quant_config
.
activation_scheme
}
"
)
if
self
.
use_marlin
:
self
.
prepare_layer_for_marlin
(
layer
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
self
.
use_marlin
:
# For GPUs that lack FP8 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP8 quantization
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
out_shape
=
x
.
shape
[:
-
1
]
+
(
layer
.
output_size_per_partition
,)
output
=
ops
.
fp8_marlin_gemm
(
a
=
reshaped_x
,
b_q_weight
=
layer
.
weight
,
b_scales
=
layer
.
weight_scale
,
workspace
=
layer
.
workspace
,
num_bits
=
8
,
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
output
.
reshape
(
out_shape
)
else
:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x
# If static, layer.input_scale is scalar and x_scale is input_scale
if
bias
is
None
and
self
.
cutlass_fp8_supported
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
input_scale
)
# Fused GEMM_DQ
output
=
ops
.
cutlass_scaled_mm
(
qinput
,
layer
.
weight
,
out_dtype
=
x
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
)
else
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
input_scale
,
batch_dim_padding
=
17
)
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
layer
.
weight
,
out_dtype
=
x
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
bias
=
bias
,
)
return
torch
.
narrow
(
output
,
0
,
0
,
x
.
shape
[
0
])
class
Fp8MoEMethod
(
FusedMoEMethodBase
):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
def
__init__
(
self
,
quant_config
:
Fp8Config
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
process_after_load
=
True
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
params_dtype
=
torch
.
float8_e4m3fn
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
,
dtype
=
params_dtype
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_scale"
,
w13_scale
)
w2_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_scale"
,
w2_scale
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
set_weight_attrs
(
w13_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_scale
,
extra_weight_attrs
)
# INPUT_SCALES
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
raise
ValueError
(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
a13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"a13_scale"
,
a13_scale
)
set_weight_attrs
(
a13_scale
,
extra_weight_attrs
)
a2_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"a2_scale"
,
a2_scale
)
set_weight_attrs
(
a2_scale
,
extra_weight_attrs
)
else
:
layer
.
a13_scale
=
None
layer
.
a2_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
not
hasattr
(
layer
,
"process_after_load"
)
or
not
layer
.
process_after_load
:
return
# If checkpoint is fp16, quantize in place.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
w2_weight
=
torch
.
empty_like
(
layer
.
w2_weight
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
layer
.
w13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
layer
.
num_experts
,
dtype
=
torch
.
float32
,
device
=
w13_weight
.
device
),
requires_grad
=
False
,
)
for
expert
in
range
(
layer
.
num_experts
):
w13_weight
[
expert
,
:,
:],
layer
.
w13_scale
[
expert
]
=
(
ops
.
scaled_fp8_quant
(
layer
.
w13_weight
.
data
[
expert
,
:,
:])
)
w2_weight
[
expert
,
:,
:],
layer
.
w2_scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
layer
.
w2_weight
.
data
[
expert
,
:,
:]
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
return
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else
:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
layer
.
a13_scale
is
None
or
layer
.
a2_scale
is
None
:
raise
ValueError
(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if
not
all_close_1d
(
layer
.
a13_scale
)
or
not
all_close_1d
(
layer
.
a2_scale
):
print_warning_once
(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. "
)
layer
.
a13_scale
=
torch
.
nn
.
Parameter
(
layer
.
a13_scale
.
max
(),
requires_grad
=
False
)
layer
.
a2_scale
=
torch
.
nn
.
Parameter
(
layer
.
a2_scale
.
max
(),
requires_grad
=
False
)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert
layer
.
w13_scale
is
not
None
shard_size
=
layer
.
intermediate_size_per_partition
max_w13_scales
=
layer
.
w13_scale
.
max
(
dim
=
1
).
values
for
expert_id
in
range
(
layer
.
num_experts
):
start
=
0
for
shard_id
in
range
(
2
):
dq_weight
=
per_tensor_dequantize
(
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
layer
.
w13_scale
[
expert_id
][
shard_id
],
)
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:]
=
(
per_tensor_quantize
(
dq_weight
,
max_w13_scales
[
expert_id
])
)
start
+=
shard_size
layer
.
w13_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
return
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
)
->
torch
.
Tensor
:
return
fused_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
router_logits
,
top_k
,
renormalize
=
renormalize
,
inplace
=
True
,
use_fp8
=
True
,
w1_scale
=
layer
.
w13_scale
,
w2_scale
=
layer
.
w2_scale
,
a1_scale
=
layer
.
a13_scale
,
a2_scale
=
layer
.
a2_scale
,
)
# FIXME: not used
class
Fp8KVCacheMethod
(
QuantizeMethodBase
):
"""Supports loading kv-cache scaling factors from FP8 checkpoints."""
def
__init__
(
self
,
quant_config
:
Fp8Config
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
):
"""Create "weight" (aka kv_scale) for an attention layer.
Args:
layer: The layer that is using the QuantizeMethodBase factory.
"""
# Initialize the KV cache scale to 1.0 as the default value.
# If the kv_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer
.
kv_scale
=
Parameter
(
torch
.
tensor
(
1.0
),
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
)
->
torch
.
Tensor
:
raise
RuntimeError
(
"Fp8KVCacheMethod.apply should not be called."
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
if
layer
.
kv_cache_dtype
!=
"auto"
:
kv_scale
=
layer
.
kv_scale
.
to
(
"cpu"
).
tolist
()
if
not
isinstance
(
kv_scale
,
float
):
raise
ValueError
(
"Only support per-tensor scaling factor "
"for fp8 KV cache"
)
layer
.
_kv_scale
=
kv_scale
if
layer
.
_kv_scale
==
1.0
and
"e5m2"
not
in
layer
.
kv_cache_dtype
:
print_warning_once
(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
"cause accuracy issues. Please make sure kv-cache scaling "
"factor is available in the fp8 checkpoint."
)
del
layer
.
kv_scale
def
per_tensor_quantize
(
tensor
:
torch
.
Tensor
,
inv_scale
:
Union
[
float
,
torch
.
Tensor
]
)
->
torch
.
Tensor
:
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
qweight
=
(
tensor
/
inv_scale
).
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)
return
qweight
.
to
(
torch
.
float8_e4m3fn
)
def
per_tensor_dequantize
(
tensor
:
torch
.
Tensor
,
inv_scale
:
Union
[
float
,
torch
.
Tensor
]
)
->
torch
.
Tensor
:
fake_qweight
=
tensor
.
to
(
torch
.
float16
)
dq_weight
=
fake_qweight
*
inv_scale
return
dq_weight
def
all_close_1d
(
x
:
torch
.
Tensor
)
->
bool
:
assert
len
(
x
.
shape
)
==
1
return
all
(
torch
.
allclose
(
x
[
0
],
x
[
i
])
for
i
in
range
(
x
.
shape
[
0
]))
python/sglang/srt/model_loader/model_loader.py
0 → 100644
View file @
2d96da81
# temporarily adapted from https://github.com/vllm-project/vllm/blob/10383887e03412196a2689b9398290719c4797bf/vllm/model_executor/model_loader/loader.py
# FIXME: in progress of refactoring the model loader
import
glob
import
os
import
re
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
torch
import
nn
from
tqdm
import
tqdm
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
SchedulerConfig
,
)
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
from
vllm.model_executor.model_loader.utils
import
(
get_model_architecture
,
set_default_torch_dtype
,
)
from
vllm.platforms
import
current_platform
from
sglang.srt.model_loader.utils
import
(
download_safetensors_index_file_from_hf
,
download_weights_from_hf
,
filter_duplicate_safetensors_files
,
get_quant_config
,
safetensors_weights_iterator
,
)
def
_get_quantization_config
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
)
->
Optional
[
QuantizationConfig
]:
"""Get the quantization config."""
if
model_config
.
quantization
is
not
None
:
quant_config
=
get_quant_config
(
model_config
,
load_config
)
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
<
quant_config
.
get_min_capability
():
raise
ValueError
(
f
"The quantization method
{
model_config
.
quantization
}
is not "
"supported for the current GPU. "
f
"Minimum capability:
{
quant_config
.
get_min_capability
()
}
. "
f
"Current capability:
{
capability
}
."
)
supported_dtypes
=
quant_config
.
get_supported_act_dtypes
()
if
model_config
.
dtype
not
in
supported_dtypes
:
raise
ValueError
(
f
"
{
model_config
.
dtype
}
is not supported for quantization "
f
"method
{
model_config
.
quantization
}
. Supported dtypes: "
f
"
{
supported_dtypes
}
"
)
return
quant_config
return
None
def
_get_model_initialization_kwargs
(
model_class
:
Type
[
nn
.
Module
],
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
)
->
Dict
[
str
,
Any
]:
"""Get extra kwargs for model initialization."""
extra_kwargs
:
Dict
[
str
,
Any
]
=
{}
assert
lora_config
is
None
assert
multimodal_config
is
None
return
extra_kwargs
def
_initialize_model
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
cache_config
:
CacheConfig
,
)
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
model_class
=
get_model_architecture
(
model_config
)[
0
]
quant_config
=
_get_quantization_config
(
model_config
,
load_config
)
return
model_class
(
config
=
model_config
.
hf_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
**
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
multimodal_config
),
)
class
ModelLoader
:
"""Model loader that can load different file types from disk."""
def
__init__
(
self
,
load_config
:
LoadConfig
):
self
.
load_config
=
load_config
def
_prepare_weights
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
],
fall_back_to_pt
:
bool
)
->
Tuple
[
str
,
List
[
str
],
bool
]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
load_format
=
self
.
load_config
.
load_format
use_safetensors
=
False
# Some quantized models use .pt files for storing the weights.
if
load_format
==
LoadFormat
.
AUTO
:
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
elif
load_format
==
LoadFormat
.
SAFETENSORS
:
use_safetensors
=
True
allow_patterns
=
[
"*.safetensors"
]
elif
load_format
==
LoadFormat
.
PT
:
allow_patterns
=
[
"*.pt"
]
elif
load_format
==
LoadFormat
.
NPCACHE
:
allow_patterns
=
[
"*.bin"
]
else
:
raise
ValueError
(
f
"Unknown load_format:
{
load_format
}
"
)
if
fall_back_to_pt
:
allow_patterns
+=
[
"*.pt"
]
if
not
is_local
:
hf_folder
=
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
allow_patterns
,
revision
,
)
else
:
hf_folder
=
model_name_or_path
hf_weights_files
:
List
[
str
]
=
[]
for
pattern
in
allow_patterns
:
hf_weights_files
+=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
pattern
))
if
len
(
hf_weights_files
)
>
0
:
if
pattern
==
"*.safetensors"
:
use_safetensors
=
True
break
if
use_safetensors
:
# For models like Mistral-7B-Instruct-v0.3
# there are both sharded safetensors files and a consolidated
# safetensors file. Using both breaks.
# Here, we download the `model.safetensors.index.json` and filter
# any files not found in the index.
if
not
is_local
:
download_safetensors_index_file_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
revision
)
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
,
hf_folder
)
else
:
hf_weights_files
=
filter_files_not_needed_for_inference
(
hf_weights_files
)
if
len
(
hf_weights_files
)
==
0
:
raise
RuntimeError
(
f
"Cannot find any model weights with `
{
model_name_or_path
}
`"
)
return
hf_folder
,
hf_weights_files
,
use_safetensors
def
_get_weights_iterator
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
],
fall_back_to_pt
:
bool
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Get an iterator for the model weights based on the load format."""
hf_folder
,
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
model_name_or_path
,
revision
,
fall_back_to_pt
)
if
self
.
load_config
.
load_format
==
LoadFormat
.
NPCACHE
:
# Currently np_cache only support *.bin checkpoints
assert
use_safetensors
is
False
weights_iterator
=
np_cache_weights_iterator
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
hf_folder
,
hf_weights_files
,
)
elif
use_safetensors
:
weights_iterator
=
safetensors_weights_iterator
(
hf_weights_files
)
else
:
weights_iterator
=
pt_weights_iterator
(
hf_weights_files
)
return
weights_iterator
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
,
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
multimodal_config
,
cache_config
,
)
weights
=
self
.
_get_weights_iterator
(
model_config
.
model
,
model_config
.
revision
,
fall_back_to_pt
=
getattr
(
model
,
"fall_back_to_pt_during_load"
,
True
),
)
modules
=
{}
for
name
,
module
in
model
.
named_modules
():
modules
[
name
]
=
module
def
apply_quant_method
(
module
):
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
# print("before apply quant", module.weight, module.weight.dtype)
quant_method
.
process_weights_after_loading
(
module
)
# print("after apply quant", module.weight, module.weight.dtype)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if
hasattr
(
module
,
"process_weights_after_loading"
):
module
.
process_weights_after_loading
()
if
torch
.
cuda
.
current_device
()
==
0
:
weights
=
tqdm
(
weights
,
total
=
model
.
get_num_params
()
*
1.5
,
desc
=
"load model"
)
num_shard
=
{}
num_loaded
=
{}
for
name
,
loaded_weight
in
weights
:
model
.
load_weights
(
None
,
name
,
loaded_weight
)
module_name
,
shard_num
=
model
.
get_module_name
(
name
)
num_shard
[
module_name
]
=
shard_num
if
module_name
not
in
num_loaded
:
num_loaded
[
module_name
]
=
1
else
:
num_loaded
[
module_name
]
+=
1
if
num_loaded
[
module_name
]
==
num_shard
[
module_name
]:
apply_quant_method
(
modules
[
module_name
])
return
model
.
eval
()
def
get_model
(
*
,
model_config
:
ModelConfig
,
load_config
:
LoadConfig
,
device_config
:
DeviceConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
cache_config
:
CacheConfig
,
)
->
nn
.
Module
:
loader
=
ModelLoader
(
load_config
)
return
loader
.
load_model
(
model_config
=
model_config
,
device_config
=
device_config
,
lora_config
=
lora_config
,
multimodal_config
=
multimodal_config
,
parallel_config
=
parallel_config
,
scheduler_config
=
scheduler_config
,
cache_config
=
cache_config
,
)
python/sglang/srt/model_loader/utils.py
0 → 100644
View file @
2d96da81
# temporarily adapted from vLLM
# FIXME: in progress of refactoring the model loader
"""Utilities for selecting and loading models."""
import
contextlib
import
fnmatch
import
hashlib
import
json
import
logging
import
os
import
tempfile
from
typing
import
Any
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
Type
import
filelock
import
huggingface_hub.constants
import
torch
from
huggingface_hub
import
HfFileSystem
,
hf_hub_download
,
snapshot_download
from
safetensors.torch
import
load_file
,
safe_open
,
save_file
from
torch
import
nn
from
tqdm.auto
import
tqdm
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization
import
get_quantization_config
logger
=
logging
.
getLogger
(
"srt.model_loader"
)
temp_dir
=
tempfile
.
gettempdir
()
@
contextlib
.
contextmanager
def
set_default_torch_dtype
(
dtype
:
torch
.
dtype
):
"""Sets the default torch dtype to the given dtype."""
old_dtype
=
torch
.
get_default_dtype
()
torch
.
set_default_dtype
(
dtype
)
yield
torch
.
set_default_dtype
(
old_dtype
)
def
get_model_architecture
(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if
(
model_config
.
quantization
is
not
None
and
model_config
.
quantization
!=
"fp8"
and
"MixtralForCausalLM"
in
architectures
):
architectures
=
[
"QuantMixtralForCausalLM"
]
for
arch
in
architectures
:
model_cls
=
ModelRegistry
.
load_model_cls
(
arch
)
if
model_cls
is
not
None
:
return
(
model_cls
,
arch
)
raise
ValueError
(
f
"Model architectures
{
architectures
}
are not supported for now. "
f
"Supported architectures:
{
ModelRegistry
.
get_supported_archs
()
}
"
)
class
DisabledTqdm
(
tqdm
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
,
disable
=
True
)
def
get_lock
(
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
):
lock_dir
=
cache_dir
or
temp_dir
os
.
makedirs
(
os
.
path
.
dirname
(
lock_dir
),
exist_ok
=
True
)
model_name
=
model_name_or_path
.
replace
(
"/"
,
"-"
)
hash_name
=
hashlib
.
sha256
(
model_name
.
encode
()).
hexdigest
()
# add hash to avoid conflict with old users' lock files
lock_file_name
=
hash_name
+
model_name
+
".lock"
# mode 0o666 is required for the filelock to be shared across users
lock
=
filelock
.
FileLock
(
os
.
path
.
join
(
lock_dir
,
lock_file_name
),
mode
=
0o666
)
return
lock
def
download_weights_from_hf
(
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
],
allow_patterns
:
List
[
str
],
revision
:
Optional
[
str
]
=
None
,
)
->
str
:
"""Download model weights from Hugging Face Hub.
Args:
model_name_or_path (str): The model name or path.
cache_dir (Optional[str]): The cache directory to store the model
weights. If None, will use HF defaults.
allow_patterns (List[str]): The allowed patterns for the
weight files. Files matched by any of the patterns will be
downloaded.
revision (Optional[str]): The revision of the model.
Returns:
str: The path to the downloaded model weights.
"""
if
not
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
:
# Before we download we look at that is available:
fs
=
HfFileSystem
()
file_list
=
fs
.
ls
(
model_name_or_path
,
detail
=
False
,
revision
=
revision
)
# depending on what is available we download different things
for
pattern
in
allow_patterns
:
matching
=
fnmatch
.
filter
(
file_list
,
pattern
)
if
len
(
matching
)
>
0
:
allow_patterns
=
[
pattern
]
break
logger
.
info
(
"Using model weights format %s"
,
allow_patterns
)
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with
get_lock
(
model_name_or_path
,
cache_dir
):
hf_folder
=
snapshot_download
(
model_name_or_path
,
allow_patterns
=
allow_patterns
,
cache_dir
=
cache_dir
,
tqdm_class
=
DisabledTqdm
,
revision
=
revision
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
)
return
hf_folder
def
download_safetensors_index_file_from_hf
(
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
],
revision
:
Optional
[
str
]
=
None
,
)
->
None
:
"""Download hf safetensors index file from Hugging Face Hub.
Args:
model_name_or_path (str): The model name or path.
cache_dir (Optional[str]): The cache directory to store the model
weights. If None, will use HF defaults.
revision (Optional[str]): The revision of the model.
"""
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with
get_lock
(
model_name_or_path
,
cache_dir
):
try
:
# Download the safetensors index file.
hf_hub_download
(
repo_id
=
model_name_or_path
,
filename
=
SAFE_WEIGHTS_INDEX_NAME
,
cache_dir
=
cache_dir
,
revision
=
revision
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
)
# If file not found on remote or locally, we should not fail since
# only some models will have SAFE_WEIGHTS_INDEX_NAME.
except
huggingface_hub
.
utils
.
EntryNotFoundError
:
logger
.
info
(
"No %s found in remote."
,
SAFE_WEIGHTS_INDEX_NAME
)
except
huggingface_hub
.
utils
.
LocalEntryNotFoundError
:
logger
.
info
(
"No %s found in local cache."
,
SAFE_WEIGHTS_INDEX_NAME
)
# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# So, we use the SAFE_WEIGHTS_INDEX_NAME to
# look up which safetensors files should be used.
def
filter_duplicate_safetensors_files
(
hf_weights_files
:
List
[
str
],
hf_folder
:
str
)
->
List
[
str
]:
# model.safetensors.index.json is a mapping from keys in the
# torch state_dict to safetensors file holding that weight.
index_file_name
=
os
.
path
.
join
(
hf_folder
,
SAFE_WEIGHTS_INDEX_NAME
)
if
not
os
.
path
.
isfile
(
index_file_name
):
return
hf_weights_files
# Iterate through the weight_map (weight_name: safetensors files)
# to identify weights that we should use.
with
open
(
index_file_name
)
as
index_file
:
weight_map
=
json
.
load
(
index_file
)[
"weight_map"
]
weight_files_in_index
=
set
()
for
weight_name
in
weight_map
:
weight_files_in_index
.
add
(
os
.
path
.
join
(
hf_folder
,
weight_map
[
weight_name
]))
# Filter out any fields that are not found in the index file.
hf_weights_files
=
[
f
for
f
in
hf_weights_files
if
f
in
weight_files_in_index
]
return
hf_weights_files
def
safetensors_weights_iterator
(
hf_weights_files
:
List
[
str
],
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Iterate over the weights in the model safetensor files."""
for
st_file
in
hf_weights_files
:
with
safe_open
(
st_file
,
framework
=
"pt"
)
as
f
:
for
name
in
f
.
keys
():
# noqa: SIM118
param
=
f
.
get_tensor
(
name
)
yield
name
,
param
def
get_quant_config
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
)
->
QuantizationConfig
:
quant_cls
=
get_quantization_config
(
model_config
.
quantization
)
# Read the quantization config from the HF model config, if available.
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
None
)
if
hf_quant_config
is
None
:
# compressed-tensors uses a compressions_config
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"compression_config"
,
None
)
if
hf_quant_config
is
not
None
:
return
quant_cls
.
from_config
(
hf_quant_config
)
# In case of bitsandbytes/QLoRA, get quant config from the adapter model.
if
model_config
.
quantization
==
"bitsandbytes"
:
if
(
not
load_config
.
model_loader_extra_config
or
"qlora_adapter_name_or_path"
not
in
load_config
.
model_loader_extra_config
):
return
quant_cls
.
from_config
({
"adapter_name_or_path"
:
""
})
model_name_or_path
=
load_config
.
model_loader_extra_config
[
"qlora_adapter_name_or_path"
]
else
:
model_name_or_path
=
model_config
.
model
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
if
not
is_local
:
# Download the config files.
with
get_lock
(
model_name_or_path
,
load_config
.
download_dir
):
hf_folder
=
snapshot_download
(
model_name_or_path
,
revision
=
model_config
.
revision
,
allow_patterns
=
"*.json"
,
cache_dir
=
load_config
.
download_dir
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
tqdm_class
=
DisabledTqdm
,
)
else
:
hf_folder
=
model_name_or_path
possible_config_filenames
=
quant_cls
.
get_config_filenames
()
# If the quantization config is not found, use the default config.
if
not
possible_config_filenames
:
return
quant_cls
()
config_files
=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
"*.json"
))
quant_config_files
=
[
f
for
f
in
config_files
if
any
(
f
.
endswith
(
x
)
for
x
in
possible_config_filenames
)
]
if
len
(
quant_config_files
)
==
0
:
raise
ValueError
(
f
"Cannot find the config file for
{
model_config
.
quantization
}
"
)
if
len
(
quant_config_files
)
>
1
:
raise
ValueError
(
f
"Found multiple config files for
{
model_config
.
quantization
}
: "
f
"
{
quant_config_files
}
"
)
quant_config_file
=
quant_config_files
[
0
]
with
open
(
quant_config_file
,
"r"
)
as
f
:
config
=
json
.
load
(
f
)
if
model_config
.
quantization
==
"bitsandbytes"
:
config
[
"adapter_name_or_path"
]
=
model_name_or_path
return
quant_cls
.
from_config
(
config
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment