Unverified Commit 1acccb36 authored by Lianmin Zheng, committed by GitHub

Fix oom issues with fp8 for llama (#1454)

parent aa2750be
...@@ -144,17 +144,17 @@ jobs: ...@@ -144,17 +144,17 @@ jobs:
cd test/srt cd test/srt
python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_without_radix_cache python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_without_radix_cache
- name: Benchmark Offline Throughput (w/o ChunkedPrefill) - name: Benchmark Offline Throughput (w/ Triton)
timeout-minutes: 10 timeout-minutes: 10
run: | run: |
cd test/srt cd test/srt
python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_without_chunked_prefill python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_with_triton_attention_backend
- name: Benchmark Offline Throughput (w/ Triton) - name: Benchmark Offline Throughput (w/ FP8)
timeout-minutes: 10 timeout-minutes: 10
run: | run: |
cd test/srt cd test/srt
python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_with_triton_attention_backend python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default_fp8
performance-test-2-gpu: performance-test-2-gpu:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
......
@@ -305,8 +305,6 @@ class LlamaForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

-        self.param_dict = dict(self.named_parameters())
-
     @torch.no_grad()
     def forward(
         self,
@@ -374,7 +372,7 @@ class LlamaForCausalLM(nn.Module):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
-        params_dict = self.param_dict
+        params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
...
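The param_dict change above is the heart of the OOM fix. Caching dict(self.named_parameters()) as an attribute in __init__ keeps a reference to every original weight tensor for the whole lifetime of the model; most likely, when FP8 loading later replaces those parameters with quantized versions, the cached dict still pins the old high-precision tensors, so their GPU memory is never released. Building the dict inside load_weights avoids the long-lived reference. The following is a minimal, self-contained sketch of that pattern only; TinyModel, fake_fp8_quantize_, and the weights iterable are illustrative stand-ins, not sglang code, and the fp8 cast assumes a PyTorch build with torch.float8_e4m3fn.

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Illustrative stand-in for the LlamaForCausalLM pattern; not sglang code."""

    def __init__(self, cache_param_dict: bool):
        super().__init__()
        self.linear = nn.Linear(1024, 1024, bias=False)
        if cache_param_dict:
            # Old pattern: a dict of parameters cached at construction time.
            # It references the original full-precision weights for the model's lifetime.
            self.param_dict = dict(self.named_parameters())

    def fake_fp8_quantize_(self):
        # Stand-in for FP8 post-load processing: the weight Parameter is swapped
        # for a much smaller quantized tensor, and the original should become
        # garbage. A cached self.param_dict would still point at the original,
        # pinning its memory.
        quantized = self.linear.weight.detach().to(torch.float8_e4m3fn)
        self.linear.weight = nn.Parameter(quantized, requires_grad=False)

    def load_weights(self, weights):
        # New pattern: build the name -> parameter map only while loading,
        # so no extra reference outlives this call.
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            params_dict[name].data.copy_(loaded_weight)


if __name__ == "__main__":
    m = TinyModel(cache_param_dict=True)
    old_weight = m.linear.weight
    m.fake_fp8_quantize_()
    # The cached dict still holds the pre-quantization tensor.
    assert m.param_dict["linear.weight"] is old_weight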
@@ -36,6 +36,7 @@ class LlamaForClassification(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+        self.torchao_config = None
         self.quant_config = quant_config
         self.model = LlamaModel(config, quant_config=quant_config)
@@ -44,8 +45,6 @@ class LlamaForClassification(nn.Module):
         )
         self.eos_token_id = config.eos_token_id

-        self.param_dict = dict(self.named_parameters())
-
     @torch.no_grad()
     def forward(
         self,
@@ -77,7 +76,7 @@ class LlamaForClassification(nn.Module):
         return logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        params_dict = self.param_dict
+        params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "classification_head" in name:
...
@@ -307,8 +307,6 @@ class XverseForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

-        self.param_dict = dict(self.named_parameters())
-
     @torch.no_grad()
     def forward(
         self,
@@ -333,7 +331,7 @@ class XverseForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        params_dict = self.param_dict
+        params_dict = dict(self.named_parameters())

         def load_weights_per_param(name, loaded_weight):
             if "rotary_emb.inv_freq" in name or "projector" in name:
...
@@ -383,8 +383,6 @@ class XverseMoeForCausalLM(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config)

-        self.param_dict = dict(self.named_parameters())
-
     @torch.no_grad()
     def forward(
         self,
@@ -406,8 +404,7 @@ class XverseMoeForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        params_dict = self.param_dict
-
+        params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
...
@@ -22,6 +22,7 @@ from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.utils import kill_child_process
 from sglang.utils import get_exception_traceback

+DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8"
 DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600
...
 import unittest

 from sglang.test.test_utils import (
+    DEFAULT_FP8_MODEL_NAME_FOR_TEST,
     DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_MOE_MODEL_NAME_FOR_TEST,
     is_in_ci,
@@ -59,6 +60,17 @@ class TestBenchServing(unittest.TestCase):
         if is_in_ci():
             assert res["output_throughput"] > 2600

+    def test_offline_throughput_default_fp8(self):
+        res = run_bench_serving(
+            model=DEFAULT_FP8_MODEL_NAME_FOR_TEST,
+            num_prompts=500,
+            request_rate=float("inf"),
+            other_server_args=[],
+        )
+        if is_in_ci():
+            assert res["output_throughput"] > 3100
+
     def test_online_latency_default(self):
         res = run_bench_serving(
             model=DEFAULT_MODEL_NAME_FOR_TEST,
...
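The new test_offline_throughput_default_fp8 case mirrors the existing default throughput test but points at the FP8 checkpoint and asserts a higher output_throughput floor (3100 vs 2600) when running in CI. To reproduce it locally, the same invocation the new CI step uses should work: run python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default_fp8 from the test/srt directory. The 3100 threshold presumably reflects the expected FP8 speedup on the CI GPU; treat the exact number as a CI-specific tuning rather than a general performance claim.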
@@ -12,8 +12,10 @@ from sglang.test.test_utils import (

 class TestChunkedPrefill(unittest.TestCase):
-    def run_mmlu(self, disable_radix_cache, enable_mixed_chunk):
-        other_args = ["--chunked-prefill-size", "32"]
+    def run_mmlu(
+        self, disable_radix_cache, enable_mixed_chunk, chunked_prefill_size=32
+    ):
+        other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
         if disable_radix_cache:
             other_args += ["--disable-radix-cache"]
@@ -55,6 +57,11 @@ class TestChunkedPrefill(unittest.TestCase):
     def test_mixed_chunked_prefill_without_radix_cache(self):
         self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=True)

+    def test_no_chunked_prefill(self):
+        self.run_mmlu(
+            disable_radix_cache=False, enable_mixed_chunk=False, chunked_prefill_size=-1
+        )
+

 if __name__ == "__main__":
     unittest.main()
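The new test_no_chunked_prefill case replaces the benchmark that was dropped from the CI workflow above: it passes chunked_prefill_size=-1, which run_mmlu forwards to the server as --chunked-prefill-size -1. A negative value is presumably interpreted as turning chunked prefill off entirely, so this test now covers the plain (non-chunked) prefill path alongside the existing chunked and mixed-chunk variants.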