Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d63f13c1
Unverified
Commit
d63f13c1
authored
Jul 25, 2024
by
Ying Sheng
Committed by
GitHub
Jul 25, 2024
Browse files
fix: fp8 config (#723)
parent
fded6744
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
0 deletions
+51
-0
python/sglang/srt/managers/controller/model_runner.py
python/sglang/srt/managers/controller/model_runner.py
+51
-0
No files found.
python/sglang/srt/managers/controller/model_runner.py
View file @
d63f13c1
...
@@ -15,6 +15,7 @@ from flashinfer import (
...
@@ -15,6 +15,7 @@ from flashinfer import (
BatchPrefillWithRaggedKVCacheWrapper
,
BatchPrefillWithRaggedKVCacheWrapper
,
)
)
from
flashinfer.decode
import
_grouped_size_compiled_for_decode_kernels
from
flashinfer.decode
import
_grouped_size_compiled_for_decode_kernels
from
torch.nn.parameter
import
Parameter
from
vllm.config
import
DeviceConfig
,
LoadConfig
from
vllm.config
import
DeviceConfig
,
LoadConfig
from
vllm.config
import
ModelConfig
as
VllmModelConfig
from
vllm.config
import
ModelConfig
as
VllmModelConfig
from
vllm.distributed
import
(
from
vllm.distributed
import
(
...
@@ -22,6 +23,7 @@ from vllm.distributed import (
...
@@ -22,6 +23,7 @@ from vllm.distributed import (
init_distributed_environment
,
init_distributed_environment
,
initialize_model_parallel
,
initialize_model_parallel
,
)
)
from
vllm.model_executor.layers.linear
import
QKVParallelLinear
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
sglang.global_config
import
global_config
from
sglang.global_config
import
global_config
...
@@ -38,6 +40,18 @@ from sglang.srt.utils import (
...
@@ -38,6 +40,18 @@ from sglang.srt.utils import (
logger
=
logging
.
getLogger
(
"srt.model_runner"
)
logger
=
logging
.
getLogger
(
"srt.model_runner"
)
def is_llama3_405b_fp8(model_config):
    """Heuristically detect the Llama 3.1 405B fbgemm-fp8 checkpoint.

    Matches on the architecture name, the 405B-specific dimensions
    (hidden_size 16384, intermediate_size 53248, 126 layers), and the
    quantization method recorded in the HF config.

    Args:
        model_config: sglang model config exposing an ``hf_config`` attribute
            (a Hugging Face ``PretrainedConfig``-like object).

    Returns:
        bool: True iff the config matches the 405B fbgemm_fp8 checkpoint.
    """
    hf_config = model_config.hf_config
    # An unquantized checkpoint with the same dimensions has no
    # quantization_config attribute; fall back to {} so this function
    # reports False instead of raising AttributeError/KeyError.
    quant_config = getattr(hf_config, "quantization_config", None) or {}
    return (
        hf_config.architectures[0] == "LlamaForCausalLM"
        and hf_config.hidden_size == 16384
        and hf_config.intermediate_size == 53248
        and hf_config.num_hidden_layers == 126
        and quant_config.get("quant_method") == "fbgemm_fp8"
    )
class
ModelRunner
:
class
ModelRunner
:
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -118,6 +132,9 @@ class ModelRunner:
...
@@ -118,6 +132,9 @@ class ModelRunner:
seed
=
42
,
seed
=
42
,
skip_tokenizer_init
=
True
,
skip_tokenizer_init
=
True
,
)
)
if
is_llama3_405b_fp8
(
self
.
model_config
):
self
.
model_config
.
hf_config
.
num_key_value_heads
=
8
vllm_model_config
.
hf_config
.
num_key_value_heads
=
8
self
.
dtype
=
vllm_model_config
.
dtype
self
.
dtype
=
vllm_model_config
.
dtype
if
self
.
model_config
.
model_overide_args
is
not
None
:
if
self
.
model_config
.
model_overide_args
is
not
None
:
vllm_model_config
.
hf_config
.
update
(
self
.
model_config
.
model_overide_args
)
vllm_model_config
.
hf_config
.
update
(
self
.
model_config
.
model_overide_args
)
...
@@ -370,5 +387,39 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
...
@@ -370,5 +387,39 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
return
model_arch_name_to_cls
[
model_arch
]
return
model_arch_name_to_cls
[
model_arch
]
def get_original_weight(loaded_weight, head_dim):
    """Recover the original KV weight from a checkpoint with duplicated heads.

    The incoming tensor stores each KV head block twice (2 * head_dim rows
    per head). The even-numbered blocks are compacted to the front of the
    tensor **in place**, and a view of the first ``n_kv_head * head_dim``
    rows is returned.

    Args:
        loaded_weight: 2-D weight tensor of shape
            ``(n_kv_head * 2 * head_dim, dim)``; mutated in place.
        head_dim: size of a single attention head.

    Returns:
        The de-duplicated weight of shape ``(n_kv_head * head_dim, dim)``
        (a view into ``loaded_weight``).
    """
    block = 2 * head_dim
    n_kv_head = loaded_weight.shape[0] // block
    dim = loaded_weight.shape[1]
    # Move each head's first (non-duplicated) copy to its compacted slot.
    for head in range(n_kv_head):
        dst = head * head_dim
        src = head * block
        loaded_weight[dst : dst + head_dim, :] = loaded_weight[src : src + head_dim, :]
    original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
    assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
    return original_kv_weight
def get_weight_loader_srt(weight_loader):
    """Wrap a QKV weight loader to de-duplicate oversized K/V shards.

    Args:
        weight_loader: the original (unbound) ``weight_loader`` method.

    Returns:
        A drop-in replacement loader that compacts a K or V shard whose row
        count is twice ``head_size * total_num_kv_heads`` before delegating
        to ``weight_loader``.
    """

    def weight_loader_srt(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        # Only inspect self's head geometry for K/V shards (short-circuit
        # keeps behavior identical for q/other shards).
        if loaded_shard_id in ("k", "v"):
            duplicated_rows = self.head_size * self.total_num_kv_heads * 2
            if loaded_weight.shape[0] == duplicated_rows:
                loaded_weight = get_original_weight(loaded_weight, self.head_size)
        weight_loader(self, param, loaded_weight, loaded_shard_id)

    return weight_loader_srt
# Monkey patch model loader: route vLLM's model-class lookup through the
# sglang resolver, and wrap QKV weight loading so duplicated K/V shards
# are compacted before the original loader runs.
ModelRegistry.load_model_cls = load_model_cls_srt
original_weight_loader = QKVParallelLinear.weight_loader
QKVParallelLinear.weight_loader = get_weight_loader_srt(original_weight_loader)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment