change / sglang · Commits

Commit b4408b0d (unverified), authored Sep 19, 2024 by Yineng Zhang and committed via GitHub on Sep 19, 2024.
Parent: 2cd7e181

feat: update linear deps 1/N (#1305)

The commit changes 33 files in total; this page shows 20 changed files with 1414 additions and 82 deletions.
python/sglang/srt/layers/activation.py                 +3     -2
python/sglang/srt/layers/linear.py                     +1133  -0
python/sglang/srt/layers/quantization/__init__.py      +76    -0
python/sglang/srt/layers/quantization/base_config.py   +122   -0
python/sglang/srt/models/baichuan.py                   +1     -1
python/sglang/srt/models/chatglm.py                    +6     -6
python/sglang/srt/models/commandr.py                   +7     -7
python/sglang/srt/models/dbrx.py                       +7     -7
python/sglang/srt/models/deepseek.py                   +7     -7
python/sglang/srt/models/deepseek_v2.py                +7     -7
python/sglang/srt/models/exaone.py                     +6     -6
python/sglang/srt/models/gemma.py                      +6     -6
python/sglang/srt/models/gemma2.py                     +6     -6
python/sglang/srt/models/gpt_bigcode.py                +6     -6
python/sglang/srt/models/grok.py                       +6     -6
python/sglang/srt/models/internlm2.py                  +6     -6
python/sglang/srt/models/llama.py                      +6     -6
python/sglang/srt/models/llama_classification.py       +1     -1
python/sglang/srt/models/llava.py                      +1     -1
python/sglang/srt/models/llavavid.py                   +1     -1
python/sglang/srt/layers/activation.py

@@ -31,8 +31,9 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.utils import set_weight_attrs
+
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.utils import set_weight_attrs

 logger = logging.getLogger(__name__)
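The only change to activation.py is swapping the vllm imports of QuantizationConfig and set_weight_attrs for the copies now vendored in sglang. As a rough illustration of what a set_weight_attrs-style helper provides (this is a hypothetical re-implementation for readability, not the actual sglang.srt.utils code), it simply attaches loader metadata such as output_dim and weight_loader to a parameter so that the weight-loading code in linear.py below can find it later:

# Hypothetical sketch of a set_weight_attrs-style helper; the real helper
# lives in sglang.srt.utils and may differ in details.
from typing import Any, Dict, Optional

import torch


def set_weight_attrs_sketch(
    weight: torch.Tensor, weight_attrs: Optional[Dict[str, Any]]
) -> None:
    """Attach metadata (e.g. output_dim, weight_loader) to a parameter."""
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
        # Refuse to silently overwrite metadata set by another code path.
        assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
        setattr(weight, key, value)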
python/sglang/srt/layers/linear.py (new file, mode 100644)

# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/linear.py
import logging
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)

# workaround
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
)

from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.utils import set_weight_attrs

logger = logging.getLogger(__name__)

WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
]


def adjust_marlin_shard(param, shard_size, shard_offset):
    marlin_tile_size = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset

    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


def adjust_bitsandbytes_shard(
    param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
) -> Tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""

    total, _ = qkv_offsets["total"]
    orig_offset, orig_size = qkv_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
    the shard_id for loading.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight


class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for a linear layer.
        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list contains the width of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.
        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        weight = Parameter(
            torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return F.linear(x, layer.weight, bias)


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            self.input_size,
            [self.output_size],
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
            prefix=prefix,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s


class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[List[int]] = None,
        prefix: str = "",
    ):
        super().__init__(
            input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
        )

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_method is not None
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, tp_size) for output_size in self.output_sizes
            ]

        if output_sizes is None:
            output_sizes = [output_size]

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
            prefix=prefix,
        )
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s


class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: List[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[int] = None,
    ):
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.data[loaded_shard_id].copy_(loaded_weight)
            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            return

        if is_gguf_weight and isinstance(param, UninitializedParameter):
            from gguf.constants import GGML_QUANT_SIZES

            ori_shape = param.tensor_shape
            weight_types = self.qweight_type.shard_weight_type.values()
            row_size = []
            for weight_type in weight_types:
                block_size, type_size = GGML_QUANT_SIZES[weight_type]
                row_size.append(ori_shape[1] // block_size * type_size)
            q_shape = (ori_shape[0], max(row_size))
            param.materialize(q_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv/mlp).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            shard_offsets: List[Tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
            if use_bitsandbytes:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id

            if is_gguf_weight:
                tp_size = get_tensor_model_parallel_world_size()
                output_dim = getattr(param, "output_dim", None)
                shard_shape = list(loaded_weight.shape)
                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
                param.shard_id.append(loaded_shard_id)
                param.shard_size[loaded_shard_id] = shard_shape

                input_dim = getattr(param, "input_dim", None)
                input_size = loaded_weight.shape[input_dim]
                param_data = param_data.narrow(input_dim, 0, input_size)

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determmines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        current_shard_offset = 0
        shard_offsets: List[Tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, PackedvLLMParameter)
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[int] = None,
    ):
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) is BasevLLMParameter:
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()
        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
        shard_size = self.output_sizes[loaded_shard_id] // tp_size

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
        )


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (
            (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
        )
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
        )

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determmines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            (
                "k",
                self.total_num_heads * self.head_size,
                self.total_num_kv_heads * self.head_size,
            ),
            (
                "v",
                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
                self.total_num_kv_heads * self.head_size,
            ),
        ]

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, PackedvLLMParameter)
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) is BasevLLMParameter:
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type and loaded_shard_id is not None:
            idx_map = {"q": 0, "k": 1, "v": 2}
            param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            return

        if is_gguf_weight and isinstance(param, UninitializedParameter):
            from gguf.constants import GGML_QUANT_SIZES

            ori_shape = param.tensor_shape
            weight_types = self.qweight_type.shard_weight_type.values()
            row_size = []
            for weight_type in weight_types:
                block_size, type_size = GGML_QUANT_SIZES[weight_type]
                row_size.append(ori_shape[1] // block_size * type_size)
            q_shape = (ori_shape[0], max(row_size))
            param.materialize(q_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv/mlp).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                (
                    "k",
                    self.total_num_heads * self.head_size,
                    self.total_num_kv_heads * self.head_size,
                ),
                (
                    "v",
                    (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
                    self.total_num_kv_heads * self.head_size,
                ),
            ]
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
            if use_bitsandbytes:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (
                        self.num_heads * self.head_size,
                        self.num_kv_heads * self.head_size,
                    ),
                    "v": (
                        (self.num_heads + self.num_kv_heads) * self.head_size,
                        self.num_kv_heads * self.head_size,
                    ),
                    "total": (
                        (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                        0,
                    ),
                }
                shard_size, shard_offset = adjust_bitsandbytes_shard(
                    param, orig_qkv_offsets, loaded_shard_id
                )

            if is_gguf_weight:
                tp_size = get_tensor_model_parallel_world_size()
                output_dim = getattr(param, "output_dim", None)
                shard_shape = list(loaded_weight.shape)
                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
                param.shard_id.append(loaded_shard_id)
                param.shard_size[loaded_shard_id] = shard_shape

                input_dim = getattr(param, "input_dim", None)
                input_size = loaded_weight.shape[input_dim]
                param_data = param_data.narrow(input_dim, 0, input_size)

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        # Special case for for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size, shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__(
            input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Divide the weight matrix along the last dimension.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
            prefix=prefix,
        )
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        if bias:
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            if input_dim:
                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_):
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        return output, output_bias

    def extra_repr(self) -> str:
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
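The classes above mirror vllm's v0.5.5 linear layers, so the model files touched by this commit can switch to them by changing only the import path. The following is a rough usage sketch, not code from the commit: the module and dimension names are hypothetical, and it assumes the tensor-parallel process group has already been initialized by the sglang runtime.

# Hypothetical sketch of how the vendored layers are typically wired together;
# the real model definitions live under python/sglang/srt/models/.
from typing import Optional

import torch
from torch import nn

from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig


class SketchBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int = 4096,
        intermediate_size: int = 11008,
        num_heads: int = 32,
        num_kv_heads: int = 8,
        head_dim: int = 128,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # Fused q/k/v projection, sharded along the head dimension per rank.
        self.qkv_proj = QKVParallelLinear(
            hidden_size, head_dim, num_heads, num_kv_heads,
            bias=False, quant_config=quant_config,
        )
        # Attention output projection: row-parallel, all-reduced across ranks.
        self.o_proj = RowParallelLinear(
            num_heads * head_dim, hidden_size,
            bias=False, quant_config=quant_config,
        )
        # gate_proj and up_proj fused into one column-parallel weight.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False, quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size, hidden_size,
            bias=False, quant_config=quant_config,
        )
        self.act_fn = nn.SiLU()

    def forward_mlp(self, x: torch.Tensor) -> torch.Tensor:
        # Each layer returns (output, output_bias); the bias is None here.
        gate_up, _ = self.gate_up_proj(x)
        gate, up = gate_up.chunk(2, dim=-1)
        out, _ = self.down_proj(self.act_fn(gate) * up)
        return out

At checkpoint load time, a model's load_weights loop (e.g. in llama.py) calls param.weight_loader(param, tensor, shard_id) with shard ids 0/1 for the fused gate/up weight and "q"/"k"/"v" for the QKV weight; those calls land in the loaded_shard_id branches implemented above, which narrow both the destination parameter and the loaded tensor to the current tensor-parallel rank's slice.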
python/sglang/srt/layers/quantization/__init__.py (new file, mode 100644)

# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
from typing import Dict, Type

from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
    CompressedTensorsConfig,
)
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.qqq import QQQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

from sglang.srt.layers.quantization.base_config import QuantizationConfig

QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "aqlm": AQLMConfig,
    "awq": AWQConfig,
    "deepspeedfp": DeepSpeedFPConfig,
    "tpu_int8": Int8TpuConfig,
    "fp8": Fp8Config,
    "fbgemm_fp8": FBGEMMFp8Config,
    # The order of gptq methods is important for config.py iteration over
    # override_quantization_method(..)
    "marlin": MarlinConfig,
    "gguf": GGUFConfig,
    "gptq_marlin_24": GPTQMarlin24Config,
    "gptq_marlin": GPTQMarlinConfig,
    "awq_marlin": AWQMarlinConfig,
    "gptq": GPTQConfig,
    "squeezellm": SqueezeLLMConfig,
    "compressed-tensors": CompressedTensorsConfig,
    "bitsandbytes": BitsAndBytesConfig,
    "qqq": QQQConfig,
    "experts_int8": ExpertsInt8Config,
}


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")
    return QUANTIZATION_METHODS[quantization]


__all__ = [
    "QuantizationConfig",
    "get_quantization_config",
    "QUANTIZATION_METHODS",
]

"""
def fp8_get_quant_method(
    self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
    if isinstance(layer, LinearBase):
        if is_layer_skipped(prefix, self.ignored_layers):
            return UnquantizedLinearMethod()
        return Fp8LinearMethod(self)
    elif isinstance(layer, FusedMoE):
        return Fp8MoEMethod(self)
    return None

setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
"""
python/sglang/srt/layers/quantization/base_config.py (new file, mode 100644)

# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import torch
from torch import nn


class QuantizeMethodBase(ABC):
    """Base class for different quantized methods."""

    @abstractmethod
    def create_weights(
        self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs
    ):
        """Create weights for a layer.

        The weights will be set as attributes of the layer."""
        raise NotImplementedError

    @abstractmethod
    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        """Process the weight after loading.

        This can be used for example, to transpose weights for computation.
        """
        return


class QuantizationConfig(ABC):
    """Base class for quantization configs."""

    @abstractmethod
    def get_name(self) -> str:
        """Name of the quantization method."""
        raise NotImplementedError

    @abstractmethod
    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """List of supported activation dtypes."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU capability to support the quantization method.

        E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
        This requirement is due to the custom CUDA kernels used by the
        quantization method.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_config_filenames() -> List[str]:
        """List of filenames to search for in the model directory."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
        """Create a config class from the model's quantization config."""
        raise NotImplementedError

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        """
        Detects if this quantization method can support a given checkpoint
        format by overriding the user specified quantization method --
        this method should only be overwritten by subclasses in exceptional
        circumstances
        """
        return None

    @staticmethod
    def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
        """Get a value from the model's quantization config."""
        for key in keys:
            if key in config:
                return config[key]
        raise ValueError(
            f"Cannot find any of {keys} in the model's " "quantization config."
        )

    @staticmethod
    def get_from_keys_or(config: Dict[str, Any], keys: List[str], default: Any) -> Any:
        """Get a optional value from the model's quantization config."""
        try:
            return QuantizationConfig.get_from_keys(config, keys)
        except ValueError:
            return default

    @abstractmethod
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        """Get the quantize method to use for the quantized layer.

        Args:
            layer: The layer for the quant method.
            prefix: The full name of the layer in the state dict
        Returns:
            The quantize method. None if the given layer doesn't support quant
            method.
        """
        raise NotImplementedError

    @abstractmethod
    def get_scaled_act_names(self) -> List[str]:
        """Returns the activation function names that should be post-scaled.

        For now, this is only used by AWQ.
        """
        raise NotImplementedError
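base_config.py defines the interface that every quantization backend must satisfy. As an illustration only (this is a hypothetical no-op config, not one of the registered methods), a minimal subclass that wires every linear layer to UnquantizedLinearMethod shows the required surface; note that LinearBase is re-exported from sglang.srt.layers.linear via the workaround import above.

# Hypothetical example of implementing the QuantizationConfig interface;
# the real backends (fp8, awq, gptq, ...) come from vllm via the registry above.
from typing import Any, Dict, List, Optional

import torch

from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)


class NoopQuantConfig(QuantizationConfig):
    """Pass-through 'quantization' that leaves weights in params_dtype."""

    def get_name(self) -> str:
        return "noop"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70  # no custom kernels, so any reasonably recent GPU works

    @staticmethod
    def get_config_filenames() -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "NoopQuantConfig":
        return cls()

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        # Only linear layers get a method here; everything else is untouched.
        if isinstance(layer, LinearBase):
            return UnquantizedLinearMethod()
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []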
python/sglang/srt/models/baichuan.py

@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -45,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

python/sglang/srt/models/chatglm.py

@@ -24,12 +24,6 @@ from torch import nn
 from torch.nn import LayerNorm
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -40,7 +34,13 @@ from vllm.transformers_utils.configs import ChatGLMConfig
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

python/sglang/srt/models/commandr.py

@@ -50,21 +50,21 @@ from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.utils import set_weight_attrs

 @torch.compile

python/sglang/srt/models/dbrx.py

@@ -27,12 +27,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.linear import (
-    QKVParallelLinear,
-    ReplicatedLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
@@ -40,12 +34,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
+from sglang.srt.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.utils import set_weight_attrs

 class DbrxRouter(nn.Module):

python/sglang/srt/models/deepseek.py

@@ -28,13 +28,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    ReplicatedLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -44,7 +37,14 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

python/sglang/srt/models/deepseek_v2.py

@@ -27,13 +27,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.linear import (
-    ColumnParallelLinear,
-    MergedColumnParallelLinear,
-    ReplicatedLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -43,7 +36,14 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

python/sglang/srt/models/exaone.py

@@ -23,12 +23,6 @@ import torch
 from torch import nn
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -38,7 +32,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

python/sglang/srt/models/gemma.py

@@ -23,19 +23,19 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

python/sglang/srt/models/gemma2.py

@@ -22,12 +22,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -35,7 +29,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import GemmaRMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

python/sglang/srt/models/gpt_bigcode.py

@@ -23,17 +23,17 @@ from torch import nn
 from transformers import GPTBigCodeConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import (
-    ColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

python/sglang/srt/models/grok.py

@@ -28,12 +28,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.linear import (
-    QKVParallelLinear,
-    ReplicatedLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -44,7 +38,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.fused_moe import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

python/sglang/srt/models/internlm2.py

@@ -23,12 +23,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -38,7 +32,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

python/sglang/srt/models/llama.py

@@ -24,12 +24,6 @@ from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -39,7 +33,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict

python/sglang/srt/models/llama_classification.py

@@ -19,10 +19,10 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel

python/sglang/srt/models/llava.py

@@ -32,9 +32,9 @@ from transformers import (
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.config import CacheConfig
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,

python/sglang/srt/models/llavavid.py

@@ -23,9 +23,9 @@ from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.config import CacheConfig
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM

The remaining 13 changed files of this commit are listed on the second page of the diff view.