sglang / Commits / 9c8e4f69

Commit 9c8e4f69 (unverified), authored Aug 22, 2025 by Hongbo Xu, committed by GitHub Aug 21, 2025
Parent: 78ae1758

[5/n]decouple quantization implementation from vLLM dependency (#9454)

Showing 4 changed files with 562 additions and 1 deletion:

    python/sglang/srt/layers/quantization/fp8_utils.py          +4    -1
    python/sglang/srt/layers/quantization/fpgemm_fp8.py         +199  -0
    python/sglang/srt/layers/quantization/marlin_utils.py       +7    -0
    python/sglang/srt/layers/quantization/marlin_utils_fp8.py   +352  -0
python/sglang/srt/layers/quantization/fp8_utils.py

@@ -557,7 +557,10 @@ def apply_fp8_linear(
     # We also don't pad when using torch.compile,
     # as it breaks with dynamic shapes.
     if pad_output is None:
-        pad_output = not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
+        pad_output = (
+            not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
+            and not cutlass_fp8_supported
+        )
     output_padding = 17 if pad_output else None

     # View input as 2D matrix for fp8 methods
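Note on this hunk: padding the fp8 GEMM output to 17 rows was previously skipped only under torch.compile; after this change it is also skipped whenever the CUTLASS FP8 path is available. A minimal standalone sketch of the resulting decision (not part of the commit; env_true is a hypothetical stand-in for sglang's get_bool_env_var):

import os

def env_true(name: str) -> bool:
    # stand-in for sglang.srt.utils.get_bool_env_var (assumption, not the real helper)
    return os.environ.get(name, "false").lower() in ("1", "true", "yes")

def decide_output_padding(pad_output, cutlass_fp8_supported):
    if pad_output is None:
        # pad only when neither torch.compile nor the CUTLASS FP8 kernel is in use
        pad_output = (
            not env_true("SGLANG_ENABLE_TORCH_COMPILE")
            and not cutlass_fp8_supported
        )
    return 17 if pad_output else None

# decide_output_padding(None, cutlass_fp8_supported=True)  -> None (no padding)
# decide_output_padding(None, cutlass_fp8_supported=False) -> 17 (unless compiling)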
python/sglang/srt/layers/quantization/fpgemm_fp8.py (new file, mode 100644)
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import logging
from typing import Any, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from sglang.srt.layers.linear import LinearBase, LinearMethodBase
from sglang.srt.layers.parameter import (
    ChannelQuantScaleParameter,
    ModelWeightParameter,
)
from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    LinearMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import (
    apply_fp8_linear,
    can_auto_enable_marlin_fp8,
    cutlass_fp8_supported,
    normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.marlin_utils_fp8 import (
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
)
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped, replace_parameter
from sglang.srt.utils import get_bool_env_var, is_cuda, is_fp8_fnuz

_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()

logger = logging.getLogger(__name__)
class FBGEMMFp8Config(QuantizationConfig):
    """Config class for FBGEMM Fp8."""

    def __init__(self, ignore_list: list[str], input_scale_ub: float):
        super().__init__()
        self.ignore_list = ignore_list if ignore_list else []
        self.input_scale_ub = input_scale_ub

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        # self.use_marlin = not marlin_fp8_supported()
        self.use_marlin = False
        if _is_cuda:
            force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
            auto_enable = can_auto_enable_marlin_fp8()
            self.use_marlin = force_marlin or auto_enable

    @classmethod
    def get_name(cls) -> str:
        return "fbgemm_fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.float16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> FBGEMMFp8Config:
        ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
        input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
        return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignore_list,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            return FBGEMMFp8LinearMethod(self)
        return None
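Illustrative aside (not part of the file): how this config is meant to be built from a checkpoint's quantization section and queried for a quant method. The dict keys match what from_config() reads above; the values and the layer prefix are made up.

from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config

hf_quant_section = {
    "modules_to_not_convert": ["lm_head"],   # hypothetical ignore list
    "activation_scale_ub": 1200.0,           # hypothetical activation upper bound
}
cfg = FBGEMMFp8Config.from_config(hf_quant_section)
assert cfg.get_name() == "fbgemm_fp8"

# For a linear layer whose prefix is in the ignore list, get_quant_method()
# returns UnquantizedLinearMethod(); otherwise FBGEMMFp8LinearMethod(cfg).
# method = cfg.get_quant_method(layer, prefix="model.layers.0.mlp.down_proj")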
class FBGEMMFp8LinearMethod(LinearMethodBase):

    def __init__(self, quant_config: FBGEMMFp8Config):
        self.quant_config = quant_config
        # self.fp8_linear = Fp8LinearOp(
        #     act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN)
        self.out_dtype = torch.get_default_dtype()
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # maybe_create_device_identity()

        weight_loader = extra_weight_attrs.get("weight_loader")
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE UPPER BOUND
        input_scale_ub = torch.nn.Parameter(
            torch.tensor((self.quant_config.input_scale_ub), dtype=torch.float32),
            requires_grad=False,
        )
        layer.input_scale_ub = input_scale_ub

    def process_weights_after_loading(self, layer: Module) -> None:
        # required by torch.compile
        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        weight = layer.weight

        if _is_fp8_fnuz:
            weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                weight=weight, weight_scale=layer.weight_scale, input_scale=None
            )
            if input_scale is not None:
                layer.input_scale = Parameter(input_scale, requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)

        layer.weight = Parameter(weight.t(), requires_grad=False)
        if self.quant_config.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale_ub

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.quant_config.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            input_scale_ub=layer.input_scale_ub,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=False,
        )
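Illustrative aside (not part of the file): the per-output-channel scheme implied by create_weights(), with weights stored as float8_e4m3fn of shape (N, K) and one float32 scale per output channel. A reference dequantization sketch on random data (the real apply() path dispatches to fused fp8 GEMM or marlin kernels instead):

import torch

N, K = 128, 256
w_fp8 = torch.randn(N, K).to(torch.float8_e4m3fn)       # like layer.weight at load time
w_scale = torch.rand(N, 1, dtype=torch.float32) + 0.5    # like layer.weight_scale

w_ref = w_fp8.to(torch.float32) * w_scale                # per-output-channel dequant
x = torch.randn(4, K)
y_ref = x @ w_ref.t()                                    # reference for apply(layer, x)
print(y_ref.shape)                                       # torch.Size([4, 128])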
python/sglang/srt/layers/quantization/marlin_utils.py

@@ -306,6 +306,13 @@ def marlin_permute_scales(
     return s


+def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
+    origin_shape = s.shape
+    _, scale_perm_single = get_scale_perms()
+    s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
+    return s.reshape(*origin_shape).contiguous()
+
+
 def marlin_moe_permute_scales(
     s: torch.Tensor,
     size_k: int,
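Rough shape check for the new helper (illustrative, not from the commit): marlin_permute_bias reorders a per-output-channel bias with the same "single" scale permutation used for channel-wise scales, so the bias can later be handed to the marlin GEMM as b_bias. It assumes size_n is a multiple of the permutation length.

import torch
from sglang.srt.layers.quantization.marlin_utils import (
    get_scale_perms,
    marlin_permute_bias,
)

_, scale_perm_single = get_scale_perms()
size_n = 4 * len(scale_perm_single)               # any multiple of the perm length
bias = torch.arange(size_n, dtype=torch.float16)
permuted = marlin_permute_bias(bias)
assert permuted.shape == bias.shape               # same shape, values reordered
assert torch.equal(permuted.sort().values, bias)  # a pure permutation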
python/sglang/srt/layers/quantization/marlin_utils_fp8.py (new file, mode 100644)
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Optional

import torch

from sglang.srt.layers.quantization.marlin_utils import (
    USE_FP32_REDUCE_DEFAULT,
    marlin_make_workspace,
    marlin_permute_bias,
    marlin_permute_scales,
    should_use_atomic_add_reduce,
)
from sglang.srt.layers.quantization.utils import get_scalar_types
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()
if _is_cuda:
    from sgl_kernel import gptq_marlin_gemm, gptq_marlin_repack

ScalarType, scalar_types = get_scalar_types()

logger = logging.getLogger(__name__)


def fp8_fused_exponent_bias_into_scales(scales):
    fp8_exponent = 4
    if scales.dtype == torch.half:
        target_exponent = 5
    elif scales.dtype == torch.bfloat16:
        target_exponent = 8
    # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
    # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
    exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp8_exponent - 1)
    s = torch.ones_like(scales) * 2
    s = s**exponent_bias
    return scales * s
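Quick numeric check (illustrative, not from the commit) of the folding above: fp16 scales come out multiplied by 2**8 and bf16 scales by 2**120, matching the exponent_bias values noted in the comments.

import torch
from sglang.srt.layers.quantization.marlin_utils_fp8 import (
    fp8_fused_exponent_bias_into_scales,
)

s_fp16 = torch.full((2, 4), 0.5, dtype=torch.half)
assert torch.allclose(
    fp8_fused_exponent_bias_into_scales(s_fp16).float(), s_fp16.float() * 2.0**8
)

s_bf16 = torch.full((2, 4), 0.5, dtype=torch.bfloat16)
assert torch.allclose(
    fp8_fused_exponent_bias_into_scales(s_bf16).float(), s_bf16.float() * 2.0**120
)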
def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
    # For GPUs that lack FP8 hardware support, we can leverage the
    # Marlin kernel for fast weight-only FP8 quantization

    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n,)

    use_atomic_add = should_use_atomic_add_reduce(
        m=reshaped_x.size(0),
        n=size_n,
        k=size_k,
        device=input.device,
        dtype=input.dtype,
    )

    output = gptq_marlin_gemm(
        a=reshaped_x,
        c=None,
        b_q_weight=weight,
        b_bias=bias,
        b_scales=weight_scale,
        global_scale=None,
        b_zeros=None,
        g_idx=None,
        perm=None,
        workspace=workspace,
        b_q_type=scalar_types.float8_e4m3fn,
        size_m=reshaped_x.size(0),
        size_n=size_n,
        size_k=size_k,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
    )

    return output.reshape(out_shape)
def prepare_fp8_layer_for_marlin(
    layer: torch.nn.Module, size_k_first: bool = True
) -> None:
    logger.warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads."
    )

    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
    weight_block_size = getattr(layer, "weight_block_size", None)

    if size_k_first:
        assert layer.weight.shape == (part_size_k, part_size_n)
    else:
        assert layer.weight.shape == (part_size_n, part_size_k)

    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace(device)

    # WEIGHT
    # Repack weights to marlin format
    perm = torch.empty(0, dtype=torch.int, device=device)
    qweight = pack_fp8_to_int32(layer.weight, size_k_first)
    if not size_k_first:
        qweight = qweight.T.contiguous()

    marlin_qweight = gptq_marlin_repack(
        b_q_weight=qweight,
        perm=perm,
        size_k=part_size_k,
        size_n=part_size_n,
        num_bits=8,
    )
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    # Permute scales
    if "weight_scale" in dir(layer):
        scales = layer.weight_scale.to(layer.orig_dtype)
    elif "weight_scale_inv" in dir(layer):
        scales = layer.weight_scale_inv.to(layer.orig_dtype)
        del layer.weight_scale_inv

    group_size = -1 if weight_block_size is None else weight_block_size[1]

    # The marlin kernel only supports channel-wise and group-wise quantization,
    # so we need to convert the scales
    if weight_block_size is None:
        if scales.nelement() == 1:
            # tensor-wise quantization -> channel-wise quantization
            # (1, 1) =>(repeat)=> (1, size_n)
            scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
        elif scales.nelement() > 1 and scales.nelement() != part_size_n:
            assert part_size_n % scales.nelement() == 0
            s_size = scales.nelement()
            # tensor-wise quantization (for gate-up proj)
            # -> channel-wise quantization
            # (1, s_size) =>(repeat)=> (1, size_n)
            scales = scales.view(1, s_size)
            scales = scales.repeat_interleave(part_size_n // s_size, 1)
        else:
            # channel-wise quantization
            # (1, size_n)
            scales = scales.view(1, part_size_n)
    else:
        # block-wise quantization -> group-wise quantization
        # (size_k // block_size[1], ceil(size_n / block_size[0]))
        # =>(repeat)=> (size_k // block_size[1], size_n)
        if not size_k_first:
            scales = scales.T.contiguous()
        block_n = weight_block_size[0]
        scales = scales.repeat_interleave(block_n, 1)
        # size_n may not be divisible by block_size[0]
        scales = scales[:, :part_size_n]

    marlin_scales = marlin_permute_scales(
        s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size
    )
    marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)

    if hasattr(layer, "bias") and layer.bias is not None:
        assert layer.bias.shape == (part_size_n,)
        bias = marlin_permute_bias(layer.bias)
        layer.bias = torch.nn.Parameter(bias, requires_grad=False)
def prepare_moe_fp8_layer_for_marlin(
    layer: torch.nn.Module, size_k_first: bool = True
) -> None:
    logger.warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads."
    )

    e = layer.num_experts
    k = layer.hidden_size
    n = layer.intermediate_size_per_partition
    weight_block_size = getattr(layer, "weight_block_size", None)

    # WORKSPACE
    device = layer.w13_weight.device
    layer.workspace = marlin_make_workspace(device, 4)
    perm = torch.empty(0, dtype=torch.int, device=device)

    # WEIGHT
    # Repack weights to marlin format
    for name in ["w13_weight", "w2_weight"]:
        weight = getattr(layer, name)
        tensor_list = []
        if "w13" in name:
            size_n, size_k = n * 2, k
        else:
            size_n, size_k = k, n

        if size_k_first:
            assert weight.shape == (e, size_k, size_n)
        else:
            assert weight.shape == (e, size_n, size_k)

        for i in range(e):
            qweight = pack_fp8_to_int32(weight[i], size_k_first)
            if not size_k_first:
                qweight = qweight.T.contiguous()

            marlin_qweight = gptq_marlin_repack(
                b_q_weight=qweight,
                perm=perm,
                size_k=size_k,
                size_n=size_n,
                num_bits=8,
            )
            tensor_list.append(marlin_qweight)

        weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
        weight = torch.nn.Parameter(weight, requires_grad=False)

        setattr(layer, name, weight)

    # WEIGHT SCALES
    # Permute scales
    group_size = -1 if weight_block_size is None else weight_block_size[1]

    for name in ["w13", "w2"]:
        if name + "_weight_scale" in dir(layer):
            new_name = name + "_weight_scale"
            scales = getattr(layer, new_name).to(layer.orig_dtype)
            delattr(layer, new_name)
        elif name + "_weight_scale_inv" in dir(layer):
            new_name = name + "_weight_scale_inv"
            scales = getattr(layer, new_name).to(layer.orig_dtype)
            delattr(layer, new_name)

        tensor_list = []
        if "w13" in name:
            size_n, size_k = n * 2, k
        else:
            size_n, size_k = k, n

        # The marlin kernel only supports channel-wise and group-wise quantization,
        # so we need to convert the scales
        if weight_block_size is None:
            if scales.nelement() == e:
                # tensor-wise quantization -> channel-wise quantization
                # (e, 1, 1) =>(repeat)=> (e, 1, size_n)
                scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2)
            elif scales.nelement() > e and scales.nelement() != e * size_n:
                assert (e * size_n) % scales.nelement() == 0
                s_size = scales.nelement() // e
                # tensor-wise quantization (for gate-up proj)
                # -> channel-wise quantization
                # (e, 1, s_size) =>(repeat)=> (e, 1, size_n)
                scales = scales.view(e, 1, s_size)
                scales = scales.repeat_interleave(size_n // s_size, 2)
            else:
                # channel-wise quantization
                # (e, 1, size_n)
                scales = scales.view(e, 1, size_n)
        else:
            # block-wise quantization -> group-wise quantization
            # (e, size_k // block_size[1], ceil(size_n / block_size[0]))
            # =>(repeat)=> (e, size_k // block_size[1], size_n)
            if not size_k_first:
                scales = scales.permute(0, 2, 1)
            block_n = weight_block_size[0]
            scales = scales.repeat_interleave(block_n, 2)
            # size_n may not be divisible by block_size[0]
            scales = scales[..., :size_n].contiguous()

        for i in range(e):
            marlin_scales = marlin_permute_scales(
                s=scales[i], size_k=size_k, size_n=size_n, group_size=group_size
            )
            tensor_list.append(marlin_scales)

        scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
        scales = fp8_fused_exponent_bias_into_scales(scales)
        scales = torch.nn.Parameter(scales, requires_grad=False)

        setattr(layer, name + "_weight_scale", scales)

    # BIAS
    # Permute bias
    for name in ["w13_bias", "w2_bias"]:
        if not hasattr(layer, name):
            continue
        bias = getattr(layer, name).to(layer.orig_dtype)

        tensor_list = []
        for i in range(e):
            expert_bias = bias[i]
            tensor_list.append(marlin_permute_bias(expert_bias))

        bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
        bias = torch.nn.Parameter(bias, requires_grad=False)
        setattr(layer, name, bias)
def pack_fp8_to_int32(
    fp8_tensor: torch.Tensor, size_k_first: bool = True
) -> torch.Tensor:
    """
    Repack FP8 weights to gptq format (packed int32 elements)
    """
    assert fp8_tensor.dtype == torch.float8_e4m3fn
    assert fp8_tensor.ndim == 2

    fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
    fp8_tensor = fp8_tensor.contiguous()
    # fp8_tensor is contiguous and has shape (N, K) now;
    # with `.view(torch.int32)`, it becomes (N, K // 4)
    int32_tensor = fp8_tensor.view(torch.int32)
    return int32_tensor.T.contiguous() if size_k_first else int32_tensor
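Small CPU demo (illustrative, not part of the file) of the packing contract documented above: four consecutive fp8 bytes are reinterpreted as one int32, so a (K, N) tensor in size_k_first layout packs to (K // 4, N).

import torch
from sglang.srt.layers.quantization.marlin_utils_fp8 import pack_fp8_to_int32

k, n = 8, 4
w = torch.randn(k, n).to(torch.float8_e4m3fn)     # (K, N), size_k_first layout
packed = pack_fp8_to_int32(w, size_k_first=True)
print(packed.dtype, tuple(packed.shape))          # torch.int32 (2, 4)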
def marlin_quant_fp8_torch(weight, group_size):
    size_n, size_k = weight.shape
    device = weight.device

    if group_size != -1:
        scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448
        repeated_scales = scales.repeat_interleave(group_size, 1)
        fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
        weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
    else:
        scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448
        repeated_scales = scales.repeat_interleave(size_k, 1)
        fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
        weight_ref = fp8_weight.to(weight.dtype) * repeated_scales

    packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
    marlin_qweight = gptq_marlin_repack(
        b_q_weight=packed_weight,
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=size_k,
        size_n=size_n,
        num_bits=8,
    )

    marlin_scales = marlin_permute_scales(
        s=scales.T, size_k=size_k, size_n=size_n, group_size=group_size
    )
    marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)

    return weight_ref.T, marlin_qweight, marlin_scales
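End of the new file. A rough CUDA-only usage sketch tying the pieces together (illustrative, not part of the commit; it assumes an sgl_kernel build with the marlin kernels, that marlin_quant_fp8_torch's outputs are directly consumable by apply_fp8_marlin_linear, and that the tolerances below are loose enough for fp8 round-off): quantize a random weight with the reference helper above, then run the marlin GEMM on it.

import torch
from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace
from sglang.srt.layers.quantization.marlin_utils_fp8 import (
    apply_fp8_marlin_linear,
    marlin_quant_fp8_torch,
)

size_n, size_k = 2048, 4096
w = torch.randn(size_n, size_k, dtype=torch.half, device="cuda")
w_ref, marlin_qweight, marlin_scales = marlin_quant_fp8_torch(w, group_size=-1)

x = torch.randn(8, size_k, dtype=torch.half, device="cuda")
y = apply_fp8_marlin_linear(
    input=x,
    weight=marlin_qweight,
    weight_scale=marlin_scales,
    workspace=marlin_make_workspace(x.device),
    size_n=size_n,
    size_k=size_k,
    bias=None,
)
torch.testing.assert_close(y, x @ w_ref, rtol=5e-2, atol=5e-1)  # loose fp8 tolerance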