Unverified commit b8ab989f, authored Jan 23, 2025 by lukec, committed by GitHub Jan 22, 2025
Fix the FP8 E4M3 parsing offline scales failure bug (#3045)
parent b3393e94

Showing 1 changed file with 77 additions and 1 deletion

python/sglang/srt/model_loader/weight_utils.py (+77, -1)
@@ -27,6 +27,7 @@ import huggingface_hub.constants
 import numpy as np
 import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
+from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
@@ -650,6 +651,81 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
     return name


+# Adapted from https://github.com/vllm-project/vllm/blob/68ad4e3a8d8a66fb2a43be57471ee13a8bec4ec0/vllm/model_executor/layers/quantization/schema.py
+class KVCacheQuantSchema(BaseModel):
+    dtype: str
+    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
+    # layer indices to their per-tensor KV cache scaling factor.
+    # TODO: Consider pulling this and its validation methods out into its
+    # own schema class (tricky as its members are variable)
+    scaling_factor: Dict[int, Dict[int, float]]
+
+    @model_validator(mode="after")
+    def check_is_fp8(self) -> "KVCacheQuantSchema":
+        assert self.dtype == "float8_e4m3fn", (
+            "Loaded scaling factors intended for KV cache dtype = "
+            f"{self.dtype} rather than float8_e4m3fn!"
+        )
+        return self
+
+    @model_validator(mode="after")
+    def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_size = context["tp_size"]
+            num_hidden_layers = context["num_hidden_layers"]
+            assert len(self.scaling_factor) == tp_size, (
+                f"Loaded dictionary has TP size {len(self.scaling_factor)} "
+                f"but LLM engine is currently running with TP size {tp_size}."
+            )
+            for tp_rank, layer_maps in self.scaling_factor.items():
+                assert len(layer_maps) == num_hidden_layers, (
+                    f"KV cache scales map for TP rank {tp_rank} is malformed. "
+                    f"Expected {num_hidden_layers} layers, got "
+                    f"{len(layer_maps)}."
+                )
+            for i in range(tp_size):
+                assert (
+                    i in self.scaling_factor
+                ), f"KV cache scales map for TP rank {i} not found."
+        return self
+
+    @model_validator(mode="after")
+    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_rank = context["tp_rank"]
+            num_hidden_layers = context["num_hidden_layers"]
+            layer_scales_map = self.scaling_factor[tp_rank]
+            for i in range(num_hidden_layers):
+                assert i in layer_scales_map, (
+                    f"Could not find KV cache scales for layer {i} in "
+                    f"TP rank {tp_rank}."
+                )
+        return self
+
+
+class QuantParamSchema(BaseModel):
+    # TODO: Generalize and extend with more fields
+    # (e.g. weights/activations params) once functionality is enabled
+    model_config = ConfigDict(protected_namespaces=())
+    model_type: Optional[str]
+    kv_cache: KVCacheQuantSchema
+
+    @model_validator(mode="after")
+    def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
+        context = info.context
+        if context:
+            model_type = context.get("model_type", None)
+            if model_type is not None:
+                assert model_type == self.model_type, (
+                    f"Model type is {model_type} but loaded "
+                    f"scaling factors belonging to different "
+                    f"model type {self.model_type}!"
+                )
+        return self
+
+
 def kv_cache_scales_loader(
     filename: str,
     tp_rank: int,
@@ -681,7 +757,7 @@ def kv_cache_scales_loader(
     except json.JSONDecodeError:
         logger.error("Error decoding JSON in file '%s'.", filename)
     except Exception:
-        logger.exception("An error occurred while reading '%s'.", filename)
+        logger.error("An error occurred while reading '%s'.", filename)
     # This section is reached if and only if any of the excepts are hit
     # Return an empty iterable (list) => no KV cache scales are loaded
     # which ultimately defaults to 1.0 scales
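For reference, the scaling_factor layout described in the comments above corresponds to an offline scales file along these lines. This is a hypothetical illustration, not a file from the repository: the model_type, layer count, scale values, and file name are all made up; only the field names and nesting follow QuantParamSchema / KVCacheQuantSchema from the diff.

    import json

    # Hypothetical offline KV cache scales file matching the new schema.
    scales = {
        "model_type": "llama",
        "kv_cache": {
            "dtype": "float8_e4m3fn",  # check_is_fp8 rejects anything else
            "scaling_factor": {
                # TP rank -> {layer index -> per-tensor KV cache scale}
                "0": {"0": 0.0213, "1": 0.0198},
                "1": {"0": 0.0221, "1": 0.0205},
            },
        },
    }

    with open("kv_cache_scales.json", "w") as f:
        json.dump(scales, f, indent=2)

JSON object keys are strings; during validation pydantic (v2, default lax mode) coerces them back to the int keys declared in Dict[int, Dict[int, float]].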
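And a minimal sketch of how such a file could be validated with the new schema, assuming pydantic v2 semantics. The call site inside kv_cache_scales_loader is not shown in this diff, so this illustrates only how a validation context feeds the @model_validator hooks via info.context; the context values are made up.

    from sglang.srt.model_loader.weight_utils import QuantParamSchema

    # Values the validators compare against; hypothetical numbers for illustration.
    context = {
        "model_type": "llama",
        "tp_rank": 0,
        "tp_size": 2,
        "num_hidden_layers": 2,
    }

    with open("kv_cache_scales.json") as f:
        # check_is_fp8 / check_tp_ranks / check_current_rank / check_model_type
        # all run here and fail validation on a mismatch.
        schema = QuantParamSchema.model_validate_json(f.read(), context=context)

    # Per-layer scales for this TP rank. If loading fails inside
    # kv_cache_scales_loader, it returns an empty iterable and the engine
    # falls back to 1.0 scales, per the comments in the diff above.
    layer_scales = schema.kv_cache.scaling_factor[context["tp_rank"]]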