Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
1acccb36
Unverified
Commit
1acccb36
authored
Sep 18, 2024
by
Lianmin Zheng
Committed by
GitHub
Sep 18, 2024
Browse files
Fix oom issues with fp8 for llama (#1454)
parent
aa2750be
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
31 additions
and
19 deletions
+31
-19
.github/workflows/pr-test.yml
.github/workflows/pr-test.yml
+4
-4
python/sglang/srt/models/llama.py
python/sglang/srt/models/llama.py
+1
-3
python/sglang/srt/models/llama_classification.py
python/sglang/srt/models/llama_classification.py
+2
-3
python/sglang/srt/models/xverse.py
python/sglang/srt/models/xverse.py
+1
-3
python/sglang/srt/models/xverse_moe.py
python/sglang/srt/models/xverse_moe.py
+1
-4
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+1
-0
test/srt/test_bench_serving.py
test/srt/test_bench_serving.py
+12
-0
test/srt/test_chunked_prefill.py
test/srt/test_chunked_prefill.py
+9
-2
No files found.
.github/workflows/pr-test.yml
View file @
1acccb36
...
...
@@ -144,17 +144,17 @@ jobs:
cd test/srt
python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_without_radix_cache
-
name
:
Benchmark Offline Throughput (w/
o ChunkedPrefill
)
-
name
:
Benchmark Offline Throughput (w/
Triton
)
timeout-minutes
:
10
run
:
|
cd test/srt
python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_with
out_chunked_prefill
python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_with
_triton_attention_backend
-
name
:
Benchmark Offline Throughput (w/
Triton
)
-
name
:
Benchmark Offline Throughput (w/
FP8
)
timeout-minutes
:
10
run
:
|
cd test/srt
python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_
with_triton_attention_backend
python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_
default_fp8
performance-test-2-gpu
:
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...
...
python/sglang/srt/models/llama.py
View file @
1acccb36
...
...
@@ -305,8 +305,6 @@ class LlamaForCausalLM(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
param_dict
=
dict
(
self
.
named_parameters
())
@
torch
.
no_grad
()
def
forward
(
self
,
...
...
@@ -374,7 +372,7 @@ class LlamaForCausalLM(nn.Module):
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
params_dict
=
self
.
param_dict
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
or
"projector"
in
name
:
...
...
python/sglang/srt/models/llama_classification.py
View file @
1acccb36
...
...
@@ -36,6 +36,7 @@ class LlamaForClassification(nn.Module):
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
torchao_config
=
None
self
.
quant_config
=
quant_config
self
.
model
=
LlamaModel
(
config
,
quant_config
=
quant_config
)
...
...
@@ -44,8 +45,6 @@ class LlamaForClassification(nn.Module):
)
self
.
eos_token_id
=
config
.
eos_token_id
self
.
param_dict
=
dict
(
self
.
named_parameters
())
@
torch
.
no_grad
()
def
forward
(
self
,
...
...
@@ -77,7 +76,7 @@ class LlamaForClassification(nn.Module):
return
logits_output
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
params_dict
=
self
.
param_dict
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
"classification_head"
in
name
:
...
...
python/sglang/srt/models/xverse.py
View file @
1acccb36
...
...
@@ -307,8 +307,6 @@ class XverseForCausalLM(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
param_dict
=
dict
(
self
.
named_parameters
())
@
torch
.
no_grad
()
def
forward
(
self
,
...
...
@@ -333,7 +331,7 @@ class XverseForCausalLM(nn.Module):
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
self
.
param_dict
params_dict
=
dict
(
self
.
named_parameters
())
def
load_weights_per_param
(
name
,
loaded_weight
):
if
"rotary_emb.inv_freq"
in
name
or
"projector"
in
name
:
...
...
python/sglang/srt/models/xverse_moe.py
View file @
1acccb36
...
...
@@ -383,8 +383,6 @@ class XverseMoeForCausalLM(nn.Module):
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
param_dict
=
dict
(
self
.
named_parameters
())
@
torch
.
no_grad
()
def
forward
(
self
,
...
...
@@ -406,8 +404,7 @@ class XverseMoeForCausalLM(nn.Module):
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
self
.
param_dict
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
...
...
python/sglang/test/test_utils.py
View file @
1acccb36
...
...
@@ -22,6 +22,7 @@ from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from
sglang.srt.utils
import
kill_child_process
from
sglang.utils
import
get_exception_traceback
DEFAULT_FP8_MODEL_NAME_FOR_TEST
=
"neuralmagic/Meta-Llama-3.1-8B-FP8"
DEFAULT_MODEL_NAME_FOR_TEST
=
"meta-llama/Meta-Llama-3.1-8B-Instruct"
DEFAULT_MOE_MODEL_NAME_FOR_TEST
=
"mistralai/Mixtral-8x7B-Instruct-v0.1"
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
=
600
...
...
test/srt/test_bench_serving.py
View file @
1acccb36
import
unittest
from
sglang.test.test_utils
import
(
DEFAULT_FP8_MODEL_NAME_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_MOE_MODEL_NAME_FOR_TEST
,
is_in_ci
,
...
...
@@ -59,6 +60,17 @@ class TestBenchServing(unittest.TestCase):
if
is_in_ci
():
assert
res
[
"output_throughput"
]
>
2600
def
test_offline_throughput_default_fp8
(
self
):
res
=
run_bench_serving
(
model
=
DEFAULT_FP8_MODEL_NAME_FOR_TEST
,
num_prompts
=
500
,
request_rate
=
float
(
"inf"
),
other_server_args
=
[],
)
if
is_in_ci
():
assert
res
[
"output_throughput"
]
>
3100
def
test_online_latency_default
(
self
):
res
=
run_bench_serving
(
model
=
DEFAULT_MODEL_NAME_FOR_TEST
,
...
...
test/srt/test_chunked_prefill.py
View file @
1acccb36
...
...
@@ -12,8 +12,10 @@ from sglang.test.test_utils import (
class
TestChunkedPrefill
(
unittest
.
TestCase
):
def
run_mmlu
(
self
,
disable_radix_cache
,
enable_mixed_chunk
):
other_args
=
[
"--chunked-prefill-size"
,
"32"
]
def
run_mmlu
(
self
,
disable_radix_cache
,
enable_mixed_chunk
,
chunked_prefill_size
=
32
):
other_args
=
[
"--chunked-prefill-size"
,
str
(
chunked_prefill_size
)]
if
disable_radix_cache
:
other_args
+=
[
"--disable-radix-cache"
]
...
...
@@ -55,6 +57,11 @@ class TestChunkedPrefill(unittest.TestCase):
def
test_mixed_chunked_prefill_without_radix_cache
(
self
):
self
.
run_mmlu
(
disable_radix_cache
=
True
,
enable_mixed_chunk
=
True
)
def
test_no_chunked_prefill
(
self
):
self
.
run_mmlu
(
disable_radix_cache
=
False
,
enable_mixed_chunk
=
False
,
chunked_prefill_size
=-
1
)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment