Unverified commit 7d671e4a, authored by Lianmin Zheng, committed by GitHub

Enable overlap by default (#2067)

parent 699384cb
@@ -220,7 +220,8 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
     return reqs


-def _extend(reqs, model_runner):
+@torch.no_grad
+def extend(reqs, model_runner):
     batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
@@ -236,15 +237,8 @@ def _extend(reqs, model_runner):
     return next_token_ids, logits_output.next_token_logits, batch


-def extend(reqs, model_runner):
-    # Disable inference mode for now when torch TP is applied. We can remove
-    # this workaround once DTensor adds support for inference mode.
-    use_inf_mode = not model_runner.torch_tp_applied
-    with torch.inference_mode(use_inf_mode):
-        return _extend(reqs, model_runner)
-
-
-def _decode(input_token_ids, batch, model_runner):
+@torch.no_grad
+def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
@@ -254,14 +248,6 @@ def _decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits


-def decode(input_token_ids, batch, model_runner):
-    # Disable inference mode for now when torch TP is applied. We can remove
-    # this workaround once DTensor adds support for inference mode.
-    use_inf_mode = not model_runner.torch_tp_applied
-    with torch.inference_mode(use_inf_mode):
-        return _decode(input_token_ids, batch, model_runner)
-
-
 def correctness_test(
     server_args,
     port_args,
...
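A note on the hunks above: the per-call torch.inference_mode(use_inf_mode) wrapper (kept as a DTensor / torch-TP workaround) is dropped in favor of a plain @torch.no_grad decorator on the merged extend/decode helpers. A minimal sketch of the two styles, using a toy model rather than sglang's ModelRunner (the names below are illustrative only):

import torch

@torch.no_grad()  # autograd is off for every call to this function
def decode_step(model: torch.nn.Module, input_ids: torch.Tensor) -> torch.Tensor:
    logits = model(input_ids)      # no autograd graph is recorded here
    return logits.argmax(dim=-1)   # toy greedy next-token pick

# The previous code toggled inference mode per call instead, so it could be
# skipped when torch tensor parallelism was active:
def decode_step_old(model, input_ids, use_inf_mode: bool):
    with torch.inference_mode(use_inf_mode):
        logits = model(input_ids)
        return logits.argmax(dim=-1)

toy = torch.nn.Linear(8, 32)                # stand-in "model": 8 features -> 32-token vocab
print(decode_step(toy, torch.randn(1, 8)))  # tensor of predicted token ids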
@@ -87,9 +87,12 @@ class OutlinesGrammar(BaseGrammarObject):
         return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
         vocab_mask = vocab_mask[idx]
         vocab_mask.fill_(1)
-        vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))

     @staticmethod
     def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
...
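For context on the fill_vocab_mask change above: the allowed-token list is now materialized as an int64 tensor, copied to the mask's device with non_blocking=True, and written with scatter_ instead of Python-list fancy indexing, which suits the overlap scheduler's asynchronous mask updates. A self-contained sketch of the masking logic (the vocab size and token ids below are made up for illustration):

import torch

vocab_size = 16
allowed = [2, 5, 7]  # token ids the grammar would allow next (illustrative values)

mask = torch.zeros(vocab_size, dtype=torch.bool)   # one row of the per-request vocab mask
mask.fill_(1)                                      # start with every token blocked

# int64 index tensor; in the real code it is built on CPU and copied non-blockingly
tokens = torch.tensor(allowed, dtype=torch.int64)

# scatter_ writes False at each allowed index, unblocking exactly those tokens
mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))

assert mask.tolist() == [i not in allowed for i in range(vocab_size)]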
@@ -899,10 +899,7 @@ class ScheduleBatch:
         self.input_ids = self.output_ids
         self.output_ids = None

-        if self.sampling_info.penalizer_orchestrator:
-            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                self.input_ids
-            )
+        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

         # Alloc mem
         bs = len(self.reqs)
...
@@ -30,7 +30,7 @@ import torch
 import zmq

 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -102,7 +102,7 @@ class Scheduler:
         self.disable_jump_forward = server_args.disable_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
-        self.enable_overlap = server_args.enable_overlap_schedule
+        self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
@@ -159,6 +159,23 @@ class Scheduler:
                 trust_remote_code=server_args.trust_remote_code,
             )

+        # Check whether overlap can be enabled
+        if not self.is_generation:
+            self.enable_overlap = False
+            logger.info("Overlap scheduler is disabled for embedding models.")
+
+        if (
+            server_args.attention_backend == "triton"
+            or server_args.enable_double_sparsity
+            or (
+                self.model_config.attention_arch == AttentionArch.MLA
+                and not self.server_args.disable_mla
+            )
+        ):
+            self.enable_overlap = False
+            logger.info(
+                "Overlap scheduler is disabled if using triton attention backend."
+            )
+
         # Launch a tensor parallel worker
         if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
@@ -903,6 +920,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
+            torch.cuda.current_stream().synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -958,6 +976,7 @@ class Scheduler:
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
+                torch.cuda.current_stream().synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()
         else:  # embedding or reward model
@@ -1031,6 +1050,7 @@ class Scheduler:
         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
+            torch.cuda.current_stream().synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
         self.stream_output(batch.reqs)
...
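The Scheduler changes above flip the default (enable_overlap = not server_args.disable_overlap_schedule) and then force-disable overlap for configurations the overlapped worker does not support yet. A condensed, standalone sketch of that decision, with the surrounding classes reduced to plain arguments (the function and parameter names here are illustrative, not sglang API):

def resolve_enable_overlap(
    disable_overlap_schedule: bool,
    is_generation: bool,          # False for embedding / reward models
    attention_backend: str,
    enable_double_sparsity: bool,
    uses_mla: bool,               # model uses MLA attention and MLA is not disabled
) -> bool:
    enable_overlap = not disable_overlap_schedule   # overlap is now the default

    if not is_generation:
        return False                                # no decode loop to overlap

    if attention_backend == "triton" or enable_double_sparsity or uses_mla:
        return False                                # paths not yet supported by the overlap worker

    return enable_overlap

# Example: default server args on a flashinfer generation model keep overlap on.
print(resolve_enable_overlap(False, True, "flashinfer", False, False))  # True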
@@ -157,14 +157,19 @@ class TpModelWorkerClient:
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        _ = model_worker_batch.seq_lens[0].item()
+        torch.cuda.current_stream().synchronize()

-        # Push a new batch to the queue
-        model_worker_batch.sampling_info = dataclasses.replace(
-            model_worker_batch.sampling_info,
+        # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
+        sampling_info = model_worker_batch.sampling_info
+        sampling_info.update_penalties()
+        model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
+            sampling_info,
             sampling_info_done=threading.Event(),
+            scaling_penalties=sampling_info.scaling_penalties,
+            linear_penalties=sampling_info.linear_penalties,
         )
-        self.cur_sampling_info = model_worker_batch.sampling_info

+        # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))

         # Allocate output future objects
...
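The forward_batch_generation rewrite above synchronizes the current CUDA stream and then snapshots sampling_info with dataclasses.replace before queueing the batch, since the scheduler keeps mutating the original object while building the next batch. A minimal sketch of that copy-before-queue pattern; the dataclass below is a stand-in, not sglang's real SamplingBatchInfo:

import dataclasses
import queue
import threading
from typing import List, Optional

@dataclasses.dataclass
class SamplingSnapshot:                       # stand-in for the real sampling-info object
    temperatures: List[float]
    scaling_penalties: Optional[object] = None
    sampling_info_done: Optional[threading.Event] = None

work_queue: queue.Queue = queue.Queue()
live = SamplingSnapshot(temperatures=[0.7, 1.0])

# Shallow-copy the dataclass so the background worker sees a frozen view of this
# batch, and hand it a fresh Event to signal when its sampling metadata is final.
snapshot = dataclasses.replace(live, sampling_info_done=threading.Event())
work_queue.put(snapshot)

live.temperatures = [0.2]                     # scheduler rebinds fields for the next batch
assert work_queue.get().temperatures == [0.7, 1.0]

Because dataclasses.replace is shallow, fields that are mutated in place still have to be carried over explicitly, which is presumably why the diff recomputes the penalties first and passes scaling_penalties / linear_penalties into the copy.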
@@ -116,7 +116,7 @@ class ModelRunner:
         )

         if self.is_multimodal:
-            logger.warning(
+            logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
@@ -636,13 +636,11 @@ class ModelRunner:
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ) -> torch.Tensor:
         sampling_info = forward_batch.sampling_info
         if sampling_info.sampling_info_done:
             # Overlap mode: the function update_regex_vocab_mask was executed
             # in process_batch_result of the last batch.
             if sampling_info.grammars:
                 sampling_info.sampling_info_done.wait()
-            sampling_info.update_penalties()
         else:
             # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
             sampling_info.update_regex_vocab_mask()
...
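On the ModelRunner change above: in overlap mode the sampling path now only blocks on sampling_info_done when grammar masks are expected, and penalty updates have moved into the worker client, so the extra update_penalties() call is gone. A toy sketch of the Event handshake between the result-processing side (which builds the mask, syncs the stream, then sets the event) and the sampling side; all names below are illustrative:

import threading

def finish_previous_batch(done: threading.Event, shared: dict) -> None:
    # Scheduler side: build the grammar vocab mask for the next batch, then signal.
    shared["vocab_mask"] = [True, False, True]
    done.set()

def sample_next_batch(done: threading.Event, shared: dict, has_grammar: bool):
    # Sampling side: only wait if a grammar mask will actually be consumed.
    if has_grammar:
        done.wait()
    return shared.get("vocab_mask")

done, shared = threading.Event(), {}
t = threading.Thread(target=finish_previous_batch, args=(done, shared))
t.start()
print(sample_next_batch(done, shared, has_grammar=True))  # [True, False, True]
t.join()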
@@ -132,9 +132,6 @@ class SamplingBatchInfo:
         return len(self.temperatures)

     def update_penalties(self):
-        if not self.penalizer_orchestrator:
-            return
-
         self.scaling_penalties = None
         self.linear_penalties = None
@@ -176,8 +173,7 @@ class SamplingBatchInfo:
             grammar.fill_vocab_mask(self.vocab_mask, i)

     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
-        if self.penalizer_orchestrator:
-            self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

         for item in [
             "temperatures",
@@ -216,8 +212,7 @@ class SamplingBatchInfo:
         return None

     def merge_batch(self, other: "SamplingBatchInfo"):
-        if self.penalizer_orchestrator:
-            self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

         for item in [
             "temperatures",
...
@@ -123,7 +123,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
-    enable_overlap_schedule: bool = False
+    disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
     enable_torch_compile: bool = False
@@ -172,9 +172,7 @@ class ServerArgs:
             if gpu_mem < 25000:
                 self.chunked_prefill_size //= 4  # make it 2048
                 self.cuda_graph_max_bs = 4
-                logger.warning(
-                    "Automatically adjust --chunked-prefill-size for small GPUs."
-                )
+                logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")

         if not is_flashinfer_available():
             self.attention_backend = "triton"
@@ -192,15 +190,22 @@ class ServerArgs:
             self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
-            self.enable_overlap_schedule = False
-            logger.warning(
+            self.disable_overlap_schedule = True
+            logger.info(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
-                "Data parallel size is adjusted to be the same as tensor parallel size."
+                "Data parallel size is adjusted to be the same as tensor parallel size. "
+                "Overlap schedule is disabled."
             )

+        if self.enable_mixed_chunk:
+            logger.info(
+                "Overlap schedule is disabled because mixed-style chunked prefill is enabled."
+            )
+            self.disable_overlap_schedule = True
+
-        if self.enable_overlap_schedule:
+        if not self.disable_overlap_schedule:
             self.disable_jump_forward = True

     @staticmethod
@@ -624,9 +629,9 @@ class ServerArgs:
             help="Disable the NaN detection for better performance.",
         )
         parser.add_argument(
-            "--enable-overlap-schedule",
+            "--disable-overlap-schedule",
             action="store_true",
-            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
+            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
         )
         parser.add_argument(
             "--enable-mixed-chunk",
@@ -692,6 +697,11 @@ class ServerArgs:
         )

         # Deprecated arguments
+        parser.add_argument(
+            "--enable-overlap-schedule",
+            action=DeprecatedAction,
+            help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.",
+        )
         parser.add_argument(
             "--disable-flashinfer",
             action=DeprecatedAction,
...
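The ServerArgs changes above rename the dataclass field and CLI flag, and park --enable-overlap-schedule in the deprecated group so existing launch scripts fail loudly instead of silently. sglang's DeprecatedAction is its own helper; a rough plain-argparse equivalent of the pattern might look like this (an illustrative sketch, not the project's actual implementation):

import argparse

class DeprecatedAction(argparse.Action):
    """Reject a removed flag with a clear message instead of ignoring it."""

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        parser.error(self.help or f"{option_string} is deprecated.")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-overlap-schedule",
    action="store_true",
    help="Disable the overlap scheduler (it is on by default now).",
)
parser.add_argument(
    "--enable-overlap-schedule",
    action=DeprecatedAction,
    help="'--enable-overlap-schedule' is deprecated. It is enabled by default now.",
)

args = parser.parse_args([])                              # a default launch keeps overlap enabled
print("overlap on:", not args.disable_overlap_schedule)   # overlap on: True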
@@ -670,7 +670,7 @@ def run_and_check_memory_leak(
     workload_func,
     disable_radix_cache,
     enable_mixed_chunk,
-    enable_overlap,
+    disable_overlap,
     chunked_prefill_size,
 ):
     other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
@@ -678,8 +678,8 @@ def run_and_check_memory_leak(
         other_args += ["--disable-radix-cache"]
     if enable_mixed_chunk:
         other_args += ["--enable-mixed-chunk"]
-    if enable_overlap:
-        other_args += ["--enable-overlap-schedule"]
+    if disable_overlap:
+        other_args += ["--disable-overlap-schedule"]

     model = DEFAULT_MODEL_NAME_FOR_TEST
     port = random.randint(4000, 5000)
@@ -731,7 +731,7 @@ def run_and_check_memory_leak(
 def run_mmlu_test(
     disable_radix_cache=False,
     enable_mixed_chunk=False,
-    enable_overlap=False,
+    disable_overlap=False,
     chunked_prefill_size=32,
 ):
     def workload_func(base_url, model):
@@ -754,7 +754,7 @@ def run_mmlu_test(
         workload_func,
         disable_radix_cache,
         enable_mixed_chunk,
-        enable_overlap,
+        disable_overlap,
         chunked_prefill_size,
     )
...
@@ -17,8 +17,8 @@ suites = {
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
         "test_metrics.py",
+        "test_non_overlap_scheduler.py",
         "test_openai_server.py",
-        "test_overlap_schedule.py",
         "test_pytorch_sampling_backend.py",
         "test_radix_attention.py",
         "test_retract_decode.py",
...
@@ -97,8 +97,8 @@ class TestBenchServing(unittest.TestCase):
         if is_in_ci():
             self.assertLess(res["median_e2e_latency_ms"], 12000)
-            self.assertLess(res["median_ttft_ms"], 80)
-            self.assertLess(res["median_itl_ms"], 11)
+            self.assertLess(res["median_ttft_ms"], 86)
+            self.assertLess(res["median_itl_ms"], 10)

     def test_moe_offline_throughput_default(self):
         res = run_bench_serving(
...
@@ -78,10 +78,11 @@ class TestJSONConstrained(unittest.TestCase):
         self.assertIsInstance(js_obj["population"], int)

         # Make sure jump forward is triggered
-        self.assertGreater(
-            ret["meta_info"]["completion_tokens"],
-            ret["meta_info"]["completion_tokens_wo_jump_forward"],
-        )
+        # NOTE: This is skipped because overlap scheduler does not support jump forward
+        # self.assertGreater(
+        #     ret["meta_info"]["completion_tokens"],
+        #     ret["meta_info"]["completion_tokens_wo_jump_forward"],
+        # )

     def test_json_generate(self):
         self.run_decode(json_schema=self.json_schema)
...
@@ -59,7 +59,7 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
         )

         metrics = run_eval(args)
-        self.assertGreater(metrics["score"], 0.41)
+        self.assertGreater(metrics["score"], 0.40)

     def test_mgsm_en(self):
         args = SimpleNamespace(
...
@@ -12,22 +12,22 @@ from sglang.test.test_utils import run_mmlu_test
 class TestOverlapSchedule(unittest.TestCase):
     def test_no_radix_attention_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=True, chunked_prefill_size=32, enable_overlap=True
+            disable_radix_cache=True, chunked_prefill_size=32, disable_overlap=True
         )

     def test_no_radix_attention_no_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=True, chunked_prefill_size=-1, enable_overlap=True
+            disable_radix_cache=True, chunked_prefill_size=-1, disable_overlap=True
         )

     def test_radix_attention_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=False, chunked_prefill_size=32, enable_overlap=True
+            disable_radix_cache=False, chunked_prefill_size=32, disable_overlap=True
         )

     def test_radix_attention_no_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=False, chunked_prefill_size=-1, enable_overlap=True
+            disable_radix_cache=False, chunked_prefill_size=-1, disable_overlap=True
         )
...
@@ -107,7 +107,7 @@ class TestRadixCacheLPM(TestRadixCacheFCFS):
 )


-class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
+class TestRadixCacheNonOverlapLPM(TestRadixCacheFCFS):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -117,7 +117,7 @@ class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=[
-                "--enable-overlap-schedule",
+                "--disable-overlap-schedule",
                 "--chunked-prefill-size",
                 "128",
                 "--max-total-tokens",
...
+import time
 import unittest
 from types import SimpleNamespace
@@ -56,14 +57,14 @@ class TestTorchCompile(unittest.TestCase):
         return response.json()

     def test_throughput(self):
-        import time
+        # Warmup
+        res = self.run_decode(16)

         max_tokens = 256
         tic = time.time()
         res = self.run_decode(max_tokens)
         tok = time.time()
-        print(res["text"])
+        print(f"{res=}")
         throughput = max_tokens / (tok - tic)
         print(f"Throughput: {throughput} tokens/s")
         self.assertGreaterEqual(throughput, 152)
...
+import time
 import unittest
 from types import SimpleNamespace
@@ -56,10 +57,10 @@ class TestTorchCompile(unittest.TestCase):
         return response.json()

     def test_throughput(self):
-        import time
+        # Warmup
+        res = self.run_decode(16)

         max_tokens = 256
         tic = time.time()
         res = self.run_decode(max_tokens)
         tok = time.time()
...