Unverified commit 502524e2, authored Apr 21, 2025 by Juwan Yoo, committed by GitHub Apr 20, 2025

compressed_tensors: port w8a16 fp8 from vllm (#4852)

parent 4c764007
Showing 3 changed files with 156 additions and 0 deletions
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py  +1 -0
python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py  +2 -0
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py  +153 -0
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -33,6 +33,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
     CompressedTensorsW8A8Fp8,
+    CompressedTensorsW8A16Fp8,
 )
 from sglang.srt.layers.quantization.compressed_tensors.utils import (
     find_matched_target,
python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py

@@ -2,8 +2,10 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
+from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8

 __all__ = [
     "CompressedTensorsScheme",
     "CompressedTensorsW8A8Fp8",
+    "CompressedTensorsW8A16Fp8",
 ]
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py  (new file, 0 → 100644)
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, List, Optional

import torch
from compressed_tensors.quantization import QuantizationStrategy

from sglang.srt.layers.parameter import (
    ChannelQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme,
)
from sglang.srt.layers.quantization.utils import convert_to_channelwise

try:
    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
        apply_fp8_marlin_linear,
        prepare_fp8_layer_for_marlin,
    )

    MARLIN_FP8_AVAILABLE = True
except ImportError:
    MARLIN_FP8_AVAILABLE = False

    def apply_fp8_marlin_linear(*args, **kwargs):
        raise ImportError("vllm is not installed")

    def prepare_fp8_layer_for_marlin(*args, **kwargs):
        raise ImportError("vllm is not installed")


__all__ = ["CompressedTensorsW8A16Fp8"]

SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR]


class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
    def __init__(self, strategy: str, is_static_input_scheme: bool):
        self.strategy = strategy
        self.is_static_input_scheme = is_static_input_scheme

        if not MARLIN_FP8_AVAILABLE:
            raise ImportError(
                "vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
            )

    @classmethod
    def get_min_capability(cls) -> int:
        # ampere and up
        return 80

    # W8A8-Fp8 kernels support only per-tensor and per-channel cases.
    # So if we have a fused module (QKV, MLP) with per tensor scales,
    # we expand each scale to its shard's channels.
    def process_weights_after_loading(self, layer) -> None:
        if self.strategy == QuantizationStrategy.TENSOR:
            ws_channelwise = convert_to_channelwise(
                layer.weight_scale, layer.logical_widths
            )
            layer.weight_scale = torch.nn.Parameter(
                ws_channelwise, requires_grad=False
            )
        else:
            # required by torch.compile to be torch.nn.Parameter
            layer.weight_scale = torch.nn.Parameter(
                layer.weight_scale.data, requires_grad=False
            )

        # Weights must be transposed for marlin
        layer.weight = torch.nn.Parameter(layer.weight.t(), requires_grad=False)

        if self.is_static_input_scheme:
            # required by torch.compile to be torch.nn.Parameter
            layer.input_scale = torch.nn.Parameter(
                layer.input_scale.data, requires_grad=False
            )

        prepare_fp8_layer_for_marlin(layer, strategy="channel")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size: int,
        output_partition_sizes: List[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        if self.strategy == QuantizationStrategy.CHANNEL:
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty(
                    (sum(output_partition_sizes), 1), dtype=torch.float32
                ),
                output_dim=0,
                weight_loader=weight_loader,
            )
        elif self.strategy == QuantizationStrategy.TENSOR:
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
        else:
            raise ValueError(
                f"Unsupported weight strategy={self.strategy}, "
                f"supported strategies are {SUPPORTED_STRATEGIES}"
            )

        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE (to deal with converted checkpoints)
        if self.is_static_input_scheme:
            input_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("input_scale", input_scale)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return apply_fp8_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )
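Usage sketch (not part of the commit): a minimal illustration of how the new CompressedTensorsW8A16Fp8 scheme might be driven on a bare torch.nn.Module. In SGLang the scheme is actually constructed and called by the compressed-tensors config and linear-layer machinery in compressed_tensors.py; the noop_weight_loader helper and the concrete sizes below are assumptions for illustration, and a vllm install is assumed so the marlin helpers resolve.

# Hypothetical usage sketch, assuming vllm is installed so the marlin helpers
# imported by the scheme are available. The loader and sizes are stand-ins.
import torch
from compressed_tensors.quantization import QuantizationStrategy

from sglang.srt.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsW8A16Fp8,
)


def noop_weight_loader(param, loaded_weight, *args, **kwargs):
    # Stand-in for the real shard-aware loader used during checkpoint loading.
    param.data.copy_(loaded_weight)


layer = torch.nn.Module()
scheme = CompressedTensorsW8A16Fp8(
    strategy=QuantizationStrategy.CHANNEL, is_static_input_scheme=False
)
# Registers the fp8 weight and per-channel weight_scale parameters on `layer`.
scheme.create_weights(
    layer=layer,
    input_size=4096,
    output_partition_sizes=[4096],
    input_size_per_partition=4096,
    params_dtype=torch.bfloat16,
    weight_loader=noop_weight_loader,
)
# After the checkpoint's fp8 weight and scales are loaded into `layer`,
# process_weights_after_loading() repacks them for the marlin kernel and
# apply_weights() then runs the w8a16 matmul:
# scheme.process_weights_after_loading(layer)
# y = scheme.apply_weights(layer, x)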