sglang commit 30828e71 (unverified), authored Dec 29, 2024 by HAI, committed by GitHub on Dec 29, 2024
Parent: e0e09fce

AMD: set weights and scaling numbers properly for block FP8 (#2637)

Showing 3 changed files with 56 additions and 6 deletions (+56 −6):
python/sglang/srt/layers/quantization/fp8.py         (+38 −1)
python/sglang/srt/layers/quantization/fp8_kernel.py  (+10 −3)
python/sglang/srt/layers/quantization/fp8_utils.py   (+8 −2)
python/sglang/srt/layers/quantization/fp8.py
@@ -272,6 +272,19 @@ class Fp8LinearMethod(LinearMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # activation_scheme: dynamic
+                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale_inv,
+                    input_scale=None,
+                )
+                layer.weight = torch.nn.Parameter(weight, requires_grad=False)
+                layer.weight_scale_inv = torch.nn.Parameter(
+                    weight_scale, requires_grad=False
+                )
+                layer.input_scale = None
             return
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
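For orientation: with block quantization, the fp8 weight carries one inverse scale per weight tile, and weight_scale_inv stores those per-tile factors. A minimal dequantization sketch, assuming the common [128, 128] block layout (the real block size comes from quant_config.weight_block_size, which is not shown in this hunk):

    import torch

    def dequant_block_fp8(w_fp8: torch.Tensor, scale_inv: torch.Tensor,
                          block_n: int = 128, block_k: int = 128) -> torch.Tensor:
        # w_fp8: (N, K) fp8 weight; scale_inv: (ceil(N/block_n), ceil(K/block_k)) float32
        w = w_fp8.to(torch.float32)
        for i in range(0, w.shape[0], block_n):
            for j in range(0, w.shape[1], block_k):
                w[i:i + block_n, j:j + block_k] *= scale_inv[i // block_n, j // block_k]
        return w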
@@ -369,7 +382,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,
                 weight_scale=layer.weight_scale_inv,
-                input_scale=layer.input_scale,
+                input_scale=None,
                 bias=bias,
             )
@@ -553,6 +566,30 @@ class Fp8MoEMethod:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # activation_scheme: dynamic
+                w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w13_weight,
+                    weight_scale=layer.w13_weight_scale_inv,
+                    input_scale=None,
+                )
+                w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w2_weight,
+                    weight_scale=layer.w2_weight_scale_inv,
+                    input_scale=None,
+                )
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False
+                )
+                layer.w13_input_scale = None
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale, requires_grad=False
+                )
+                layer.w2_input_scale = None
             return
 
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -22,7 +22,10 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import get_device_name
+from sglang.srt.utils import get_device_name, is_hip
+
+is_hip_ = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
 
 logger = logging.getLogger(__name__)
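is_hip() is imported from sglang.srt.utils and its body is not part of this diff. A common way such a platform check is implemented, shown only as an assumption about what the helper does:

    import torch

    def is_hip() -> bool:
        # True on a ROCm/HIP build of PyTorch, where torch.version.hip is set.
        return torch.version.hip is not None

Because fp8_type_ is evaluated at import time, the dtype default below becomes platform dependent without each caller having to branch on the backend.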
@@ -73,7 +76,7 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = torch.float8_e4m3fn,
+    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
@@ -95,9 +98,13 @@ def per_token_group_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"
 
     finfo = torch.finfo(dtype)
     fp8_min = finfo.min
     fp8_max = finfo.max
+    if is_hip_:
+        fp8_max = 224.0
+        fp8_min = -fp8_max
 
     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
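The 224.0 cap narrows the HIP quantization range below even the e4m3fnuz format maximum. The format ranges themselves can be checked directly; a small illustration, independent of the kernel:

    import torch

    # e4m3fn (CUDA default) and e4m3fnuz (ROCm) have different representable ranges.
    print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
    print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
    # The hunk above further caps the HIP range at +/-224.0, below the 240.0 format max.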
python/sglang/srt/layers/quantization/fp8_utils.py
@@ -7,6 +7,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
 )
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()
 
 
 def normalize_e4m3fn_to_e4m3fnuz(
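The body of normalize_e4m3fn_to_e4m3fnuz lies outside this hunk. As a hedged sketch of what such an e4m3fn-to-e4m3fnuz normalization typically involves (based on the ONNX float8 definitions: for the same bit pattern an e4m3fnuz value is half the e4m3fn value, and the 0x80 pattern is negative zero in e4m3fn but NaN in e4m3fnuz), the helper plausibly looks like this; names and layout here are illustrative, not the committed code:

    import torch

    def _normalize_e4m3fn_to_e4m3fnuz_sketch(weight, weight_scale, input_scale=None):
        # Reinterpret e4m3fn bits as e4m3fnuz. The 0x80 pattern (-128 as int8)
        # is negative zero in e4m3fn but NaN in e4m3fnuz, so zero it first.
        as_int8 = weight.view(torch.int8)
        as_int8[as_int8 == -128] = 0
        weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)
        # For identical bits, an e4m3fnuz value is half the e4m3fn value,
        # so double the scales to keep dequantized values unchanged.
        weight_scale = weight_scale * 2.0
        if input_scale is not None:
            input_scale = input_scale * 2.0
        return weight_fnuz, weight_scale, input_scale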
@@ -63,8 +66,11 @@ def input_to_float8(
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-    scale = finfo.max / amax
-    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    fp8_max = finfo.max
+    if is_hip_:
+        fp8_max = 224.0
+    scale = fp8_max / amax
+    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
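A hypothetical round trip through input_to_float8, assuming a signature of roughly input_to_float8(x, dtype=...) (only the function name and return expression are visible in this hunk). The function returns the quantized tensor together with the reciprocal scale, so dequantization is a single multiply:

    import torch

    x = torch.randn(4, 8, dtype=torch.bfloat16)
    x_fp8, scale_inv = input_to_float8(x, dtype=torch.float8_e4m3fnuz)
    # Multiply by the returned reciprocal scale to approximately recover x.
    x_approx = x_fp8.to(torch.float32) * scale_inv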