Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
ae0f6130
You need to sign in or sign up before continuing.
Unverified
Commit
ae0f6130
authored
Jul 25, 2024
by
Ying Sheng
Committed by
GitHub
Jul 25, 2024
Browse files
Revert "fix: fp8 config" (#728)
parent
60105897
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
51 deletions
+0
-51
python/sglang/srt/managers/controller/model_runner.py
python/sglang/srt/managers/controller/model_runner.py
+0
-51
No files found.
python/sglang/srt/managers/controller/model_runner.py
View file @
ae0f6130
...
...
@@ -15,7 +15,6 @@ from flashinfer import (
BatchPrefillWithRaggedKVCacheWrapper
,
)
from
flashinfer.decode
import
_grouped_size_compiled_for_decode_kernels
from
torch.nn.parameter
import
Parameter
from
vllm.config
import
DeviceConfig
,
LoadConfig
from
vllm.config
import
ModelConfig
as
VllmModelConfig
from
vllm.distributed
import
(
...
...
@@ -23,7 +22,6 @@ from vllm.distributed import (
init_distributed_environment
,
initialize_model_parallel
,
)
from
vllm.model_executor.layers.linear
import
QKVParallelLinear
from
vllm.model_executor.models
import
ModelRegistry
from
sglang.global_config
import
global_config
...
...
@@ -40,18 +38,6 @@ from sglang.srt.utils import (
logger
=
logging
.
getLogger
(
"srt.model_runner"
)
def is_llama3_405b_fp8(model_config):
    """Return True when the HF config matches Llama-3 405B in fbgemm fp8.

    Matches on the architecture name, the 405B-specific hidden size,
    intermediate size and layer count, and finally on the "fbgemm_fp8"
    quantization method recorded in the HF quantization config.  The
    quantization_config lookup only happens after the cheaper structural
    checks pass (short-circuit), mirroring the original control flow.
    """
    hf = model_config.hf_config
    return (
        hf.architectures[0] == "LlamaForCausalLM"
        and hf.hidden_size == 16384
        and hf.intermediate_size == 53248
        and hf.num_hidden_layers == 126
        and hf.quantization_config["quant_method"] == "fbgemm_fp8"
    )
class
ModelRunner
:
def
__init__
(
self
,
...
...
@@ -132,9 +118,6 @@ class ModelRunner:
seed
=
42
,
skip_tokenizer_init
=
True
,
)
if
is_llama3_405b_fp8
(
self
.
model_config
):
self
.
model_config
.
hf_config
.
num_key_value_heads
=
8
vllm_model_config
.
hf_config
.
num_key_value_heads
=
8
self
.
dtype
=
vllm_model_config
.
dtype
if
self
.
model_config
.
model_overide_args
is
not
None
:
vllm_model_config
.
hf_config
.
update
(
self
.
model_config
.
model_overide_args
)
...
...
@@ -387,39 +370,5 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
return
model_arch_name_to_cls
[
model_arch
]
def get_original_weight(loaded_weight, head_dim):
    """Compact a KV weight whose per-head rows were duplicated 2x.

    The incoming tensor carries ``2 * head_dim`` rows per KV head; this
    copies each head's first ``head_dim``-row block (assumed to be the
    original copy — the block at the even position) down to the front of
    the tensor, then returns a view of the first
    ``n_kv_head * head_dim`` rows.

    NOTE(review): mutates *loaded_weight* in place and the returned
    tensor is a view into it, not a fresh copy.
    """
    n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
    dim = loaded_weight.shape[1]
    for head in range(n_kv_head):
        dst = slice(head * head_dim, (head + 1) * head_dim)
        src = slice(2 * head * head_dim, (2 * head + 1) * head_dim)
        loaded_weight[dst, :] = loaded_weight[src, :]
    kv_rows = n_kv_head * head_dim
    original_kv_weight = loaded_weight[:kv_rows, :]
    assert original_kv_weight.shape == (kv_rows, dim)
    return original_kv_weight
def get_weight_loader_srt(weight_loader):
    """Wrap a QKV weight loader to strip duplicated KV rows first.

    When a "k" or "v" shard arrives with exactly
    ``head_size * total_num_kv_heads * 2`` rows (i.e. twice the expected
    count), the duplicates are removed via ``get_original_weight`` before
    delegating to the wrapped *weight_loader*.  All other shards pass
    through untouched.
    """

    def weight_loader_srt(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        # Short-circuit keeps self.head_size untouched for non-KV shards.
        needs_dedup = (
            loaded_shard_id in ("k", "v")
            and loaded_weight.shape[0]
            == self.head_size * self.total_num_kv_heads * 2
        )
        if needs_dedup:
            loaded_weight = get_original_weight(loaded_weight, self.head_size)
        weight_loader(self, param, loaded_weight, loaded_shard_id)

    return weight_loader_srt
# Monkey patch the vLLM model loader so model classes resolve through
# sglang's load_model_cls_srt lookup.
ModelRegistry.load_model_cls = load_model_cls_srt

# Wrap the stock QKV weight loader (keeping a reference to the original)
# so duplicated fp8 KV rows are stripped before loading.
original_weight_loader = QKVParallelLinear.weight_loader
QKVParallelLinear.weight_loader = get_weight_loader_srt(original_weight_loader)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment