sglang / Commits / fb1f28cb

Commit fb1f28cb (unverified)
Authored Aug 11, 2024 by Lianmin Zheng; committed by GitHub on Aug 12, 2024
Parent commit: fb7421db

Clean up the comments and names under python/sglang/srt/layers (#1047)

Showing 9 changed files with 26 additions and 1633 deletions (+26, -1633).
Changed files:

  python/sglang/srt/layers/activation.py                +2    -0
  python/sglang/srt/layers/decode_attention.py          +9    -5
  python/sglang/srt/layers/extend_attention.py          +6    -1
  python/sglang/srt/layers/layernorm.py                 +2    -0
  python/sglang/srt/layers/linear.py                    +0    -884
  python/sglang/srt/layers/prefill_attention.py         +5    -0
  python/sglang/srt/layers/quantization/__init__.py     +0    -64
  python/sglang/srt/layers/quantization/fp8.py          +0    -677
  python/sglang/srt/layers/radix_attention.py           +2    -2
python/sglang/srt/layers/activation.py  (view file @ fb1f28cb)

@@ -11,6 +11,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+"""Fused operators for activation layers."""
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
...
python/sglang/srt/layers/token_attention.py → python/sglang/srt/layers/decode_attention.py  (view file @ fb1f28cb)

@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+"""
+Memory-efficient attention for decoding.
+"""
+
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
...

@@ -194,7 +198,7 @@ def _fwd_kernel_stage2(
     tl.store(out_ptrs, acc)


-def _token_att_m_fwd(
+def _decode_att_m_fwd(
     q,
     k_buffer,
     att_out,
...

@@ -254,7 +258,7 @@ def _token_att_m_fwd(
     )


-def _token_softmax_reducev_fwd(
+def _decode_softmax_reducev_fwd(
     logics,
     v_buffer,
     o,
...

@@ -292,7 +296,7 @@ def _token_softmax_reducev_fwd(
     )


-def token_attention_fwd(
+def decode_attention_fwd(
     q,
     k_buffer,
     v_buffer,
...

@@ -312,7 +316,7 @@ def token_attention_fwd(
         (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
     )

-    _token_att_m_fwd(
+    _decode_att_m_fwd(
         q,
         k_buffer,
         att_m,
...

@@ -324,7 +328,7 @@ def token_attention_fwd(
         sm_scale,
         logit_cap,
     )

-    _token_softmax_reducev_fwd(
+    _decode_softmax_reducev_fwd(
         att_m,
         v_buffer,
         o,
...
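The rename is purely mechanical: token_attention_fwd becomes decode_attention_fwd (and the private helpers gain the same decode prefix), with the module moving from token_attention.py to decode_attention.py. A minimal sketch of the only change callers need, mirroring the import update made in radix_attention.py further down:

    # Before this commit:
    # from sglang.srt.layers.token_attention import token_attention_fwd

    # After this commit:
    from sglang.srt.layers.decode_attention import decode_attention_fwd

The new name matches the serving terminology used elsewhere in the tree: prefill/extend kernels handle prompt processing, while this kernel handles single-token decoding.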
python/sglang/srt/layers/extend_attention.py  (view file @ fb1f28cb)

@@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+"""
+Memory-efficient attention for prefill.
+It supports page size = 1 and prefill with KV cache (i.e. extend).
+"""
+
 import torch
 import triton
 import triton.language as tl

-from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+from sglang.srt.layers.prefill_attention import context_attention_fwd

 CUDA_CAPABILITY = torch.cuda.get_device_capability()
...
python/sglang/srt/layers/layernorm.py  (view file @ fb1f28cb)

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+"""Fused operators for normalization layers."""
+
 from typing import Optional, Tuple, Union

 import torch
...
python/sglang/srt/layers/linear.py  deleted (100644 → 0)  (view file @ fb7421db)

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
...
limitations under the License.
"""

# temporarily adapted from https://github.com/vllm-project/vllm/blob/e76466dde2bc9525d55165ceaa600d298c7bf773/vllm/model_executor/layers/linear.py
# FIXME: refactor the linear abstraction
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)


def adjust_marlin_shard(param, shard_size, shard_offset):
    marlin_tile_size = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset
    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


def adjust_bitsandbytes_shard(
    param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
) -> Tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
    total, _ = qkv_offsets["total"]
    orig_offset, orig_size = qkv_offsets[loaded_shard_id]
    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total
    return quantized_size, quantized_offset


def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
    the shard_id for loading.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight


class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for a linear layer.
        The weights will be set as attributes of the layer."""
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.
        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Args:
        separate_bias_add: If true, add bias separately after matrix
                           multiplication.
    """

    def __init__(self, separate_bias_add: bool = False):
        self.separate_bias_add = separate_bias_add

    def create_weights(self, layer, input_size_per_partition, output_partition_sizes,
                       input_size, output_size, params_dtype, **extra_weight_attrs):
        weight = Parameter(
            torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self, layer, x, bias=None):
        weight = layer.weight
        if self.separate_bias_add:
            if bias is not None:
                return F.linear(x, weight) + bias
            return F.linear(x, weight)
        return F.linear(x, weight, bias)


class LinearBase(torch.nn.Module):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer (weights are not sharded across ranks)."""
    ...


class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p]. If gather_output is true,
    the output is all-gathered so every GPU holds Y; otherwise each GPU
    keeps its own shard Y_i = XA_i.
    """
    ...


class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.
    """
    ...


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    The weight matrix is concatenated along the output dimension and the
    layer is parallelized along the head dimension. When the number of
    key/value heads is smaller than the number of query heads (e.g.,
    multi-query/grouped-query attention), the key/value heads may be
    replicated while the query heads are partitioned.
    """
    ...


class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along its
    first dimension and X along its second dimension as:

                   -   -
                  | A_1 |
                  |  .  |
              A = |  .  |        X = [X_1, ..., X_p]
                  |  .  |
                  | A_p |
                   -   -

    If reduce_results is true, the partial outputs are all-reduced across
    the tensor-parallel group; bias is not parallelized.
    """
    ...
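The deleted linear.py was a near-verbatim copy of vLLM's linear layer module (see the "temporarily adapted from" comment above), so nothing project-specific is removed here. Assuming callers now use vLLM's implementation directly, constructing the same layers would look roughly like the sketch below; the import path is vLLM's own, the sizes are illustrative, and vLLM's tensor-parallel state must already be initialized:

    from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear

    # Illustrative sizes for a Llama-style attention block (not from this commit).
    qkv_proj = QKVParallelLinear(
        hidden_size=4096,
        head_size=128,
        total_num_heads=32,
        total_num_kv_heads=8,   # grouped-query attention: KV heads may be replicated
        bias=False,
    )
    o_proj = RowParallelLinear(input_size=4096, output_size=4096, bias=False)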
python/sglang/srt/layers/context_flashattention_nopad.py → python/sglang/srt/layers/prefill_attention.py  (view file @ fb1f28cb)

@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+"""
+Memory-efficient attention for prefill.
+It supports page size = 1.
+"""
+
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
 import torch
...
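Only the module path changes here; the kernel keeps its original name, so imports update as shown in the extend_attention.py diff above:

    # Before this commit:
    # from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd

    # After this commit:
    from sglang.srt.layers.prefill_attention import context_attention_fwd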
python/sglang/srt/layers/quantization/__init__.py  deleted (100644 → 0)  (view file @ fb7421db)

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
...
limitations under the License.
"""

# temporarily adapted from vLLM
# FIXME: in progress of refactoring the model loader
from typing import Dict, Type

from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
    CompressedTensorsConfig,
)
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

from sglang.srt.layers.quantization.fp8 import Fp8Config

QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "aqlm": AQLMConfig,
    "awq": AWQConfig,
    "deepspeedfp": DeepSpeedFPConfig,
    "fp8": Fp8Config,
    # The order of gptq methods is important for config.py iteration over
    # override_quantization_method(..)
    "marlin": MarlinConfig,
    "gptq_marlin_24": GPTQMarlin24Config,
    "gptq_marlin": GPTQMarlinConfig,
    "gptq": GPTQConfig,
    "squeezellm": SqueezeLLMConfig,
    "compressed-tensors": CompressedTensorsConfig,
    "bitsandbytes": BitsAndBytesConfig,
}


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")
    return QUANTIZATION_METHODS[quantization]


__all__ = [
    "QuantizationConfig",
    "get_quantization_config",
    "QUANTIZATION_METHODS",
]
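For context, the removed module was only a name-to-config registry: vLLM's quantization config classes plus sglang's own Fp8Config, keyed by method name. A short usage sketch of the registry as it existed before this commit (the import no longer resolves afterwards):

    # Pre-commit behaviour of the removed registry (illustrative only):
    from sglang.srt.layers.quantization import get_quantization_config

    fp8_config_cls = get_quantization_config("fp8")    # sglang's Fp8Config
    gptq_config_cls = get_quantization_config("gptq")  # re-exported from vLLM
    get_quantization_config("int3")                    # raises ValueError: Invalid quantization method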
python/sglang/srt/layers/quantization/fp8.py  deleted (100644 → 0)  (view file @ fb7421db)

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
...
limitations under the License.
"""

# adapted from https://github.com/vllm-project/vllm/blob/e76466dde2bc9525d55165ceaa600d298c7bf773/vllm/model_executor/layers/quantization/fp8.py
# FIXME refactor in progress
from typing import Any, Dict, List, Optional, Union

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase, fused_moe
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQ_MARLIN_MAX_PARALLEL,
    GPTQ_MARLIN_MIN_THREAD_N,
    GPTQMarlinState,
    marlin_permute_scales,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import pack_fp8_to_int32
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import print_warning_once

from sglang.srt.layers.linear import LinearBase, LinearMethodBase

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)


def cutlass_fp8_supported() -> bool:
    capability = current_platform.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    return ops.cutlass_scaled_mm_supports_fp8(capability)


class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected fp8 checkpoint. Please note that the "
                "format is experimental and subject to change."
            )
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

    @classmethod
    def get_name(cls) -> str:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
        )

    def get_quant_method(self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale. Also supports loading quantized
    FP16/BF16 model checkpoints with dynamic activation scaling; the weight
    scaling factor is initialized after the model weights are loaded.

    Limitations:
    1. Only supports per-tensor quantization due to torch._scaled_mm support.
    2. Only supports the float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()
        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        self.use_marlin = capability < 89

    ...


class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale. Also supports loading quantized
    FP16/BF16 model checkpoints with dynamic activation scaling; the weight
    scaling factor is initialized after the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    ...


# FIXME: not used
class Fp8KVCacheMethod(QuantizeMethodBase):
    """Supports loading kv-cache scaling factors from FP8 checkpoints."""

    ...


def per_tensor_quantize(
    tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
) -> torch.Tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)


def per_tensor_dequantize(
    tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
) -> torch.Tensor:
    fake_qweight = tensor.to(torch.float16)
    dq_weight = fake_qweight * inv_scale
    return dq_weight


def all_close_1d(x: torch.Tensor) -> bool:
    assert len(x.shape) == 1
    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
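To make the removed per-tensor helpers concrete, here is a standalone round-trip sketch reimplemented from the deleted per_tensor_quantize/per_tensor_dequantize functions. The scale computation and shapes are illustrative, and a PyTorch build with float8_e4m3fn support is assumed:

    import torch

    FP8 = torch.float8_e4m3fn

    def per_tensor_quantize(tensor, scale):
        # Same logic as the deleted helper: divide, clamp to the FP8 range, cast.
        finfo = torch.finfo(FP8)
        return (tensor / scale).clamp(min=finfo.min, max=finfo.max).to(FP8)

    def per_tensor_dequantize(tensor, scale):
        # Upcast to fp16, then undo the scaling.
        return tensor.to(torch.float16) * scale

    w = torch.randn(16, 32, dtype=torch.float16)
    scale = w.abs().max().float() / torch.finfo(FP8).max  # per-tensor scale (illustrative)
    w_fp8 = per_tensor_quantize(w, scale)
    w_back = per_tensor_dequantize(w_fp8, scale)          # approximately equals w, up to FP8 rounding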
python/sglang/srt/layers/radix_attention.py  (view file @ fb1f28cb)

@@ -20,8 +20,8 @@ from flashinfer.cascade import merge_state
 from torch import nn

 from sglang.global_config import global_config
+from sglang.srt.layers.decode_attention import decode_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
-from sglang.srt.layers.token_attention import token_attention_fwd
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.model_executor.model_runner import global_server_args_dict
...

@@ -95,7 +95,7 @@ class RadixAttention(nn.Module):
         o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)

-        token_attention_fwd(
+        decode_attention_fwd(
             q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
...
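After this commit the three Triton attention modules are named for the phase they serve: prefill_attention (page size = 1, no KV cache), extend_attention (prefill with KV cache), and decode_attention (single-token decoding). A schematic of how RadixAttention dispatches between them; the method names and argument list below are illustrative, not copied from radix_attention.py:

    from sglang.srt.model_executor.forward_batch_info import ForwardMode

    class RadixAttentionSketch:
        def forward_sketch(self, q, k, v, input_metadata):
            # Illustrative dispatch only: decode batches go to decode_attention_fwd,
            # extend/prefill batches go to extend_attention_fwd.
            if input_metadata.forward_mode == ForwardMode.DECODE:
                return self.decode_forward(q, k, v, input_metadata)  # calls decode_attention_fwd
            return self.extend_forward(q, k, v, input_metadata)      # calls extend_attention_fwd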