"vscode:/vscode.git/clone" did not exist on "28f9d84549c0b1d24ef00d69a4c723f3a11cffb6"
Unverified commit b7a065ea, authored by Lianmin Zheng, committed by GitHub

Use cuda event wait and synchronization instead of busy waiting (#2089)

parent b1104538
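
The change replaces CPU-side busy waiting on CUDA events (polling Event.query() in a short sleep loop) with a blocking Event.synchronize() call, and introduces an extra CUDA event so the background forward thread waits for the scheduler's pending GPU work before launching the next forward pass. Below is a minimal sketch contrasting the two waiting styles; the names and the matmul workload are illustrative, not taken from this repository, and a CUDA device is assumed.

import time

import torch

# Issue some asynchronous GPU work on the current stream and mark its
# completion with a CUDA event (a stand-in for the real forward/copy work).
x = torch.randn(4096, 4096, device="cuda")
y = x @ x
done = torch.cuda.Event()
done.record()

# Old style: busy-wait on the CPU, repeatedly polling the event.
while not done.query():
    time.sleep(1e-5)  # burns a CPU core and adds wake-up latency

# New style: a single blocking call; the calling thread sleeps until all
# work recorded before the event has finished on the GPU.
done.synchronize()
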
@@ -1063,7 +1063,7 @@ class ScheduleBatch:
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
-            sampling_info=dataclasses.replace(self.sampling_info),
+            sampling_info=self.sampling_info,
         )

     def __str__(self):
@@ -387,9 +387,6 @@ class Scheduler:
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch
             if batch:
-                # We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
-                _ = batch.seq_lens[0].item()
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))
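
For context on the scheduler change above: Tensor.item() copies a single element to the host and blocks until the GPU work producing that tensor has finished, so the deleted line acted as an implicit stream synchronization. With the explicit CUDA event introduced later in this commit, that workaround is presumably no longer needed. A tiny illustration with made-up values:

import torch

# Assumes a CUDA device is available.
seq_lens = torch.tensor([7, 12, 3], device="cuda")

# .item() performs a blocking device-to-host copy of one element, so it
# also waits for any pending kernels that write to `seq_lens`.
first_len = seq_lens[0].item()
print(first_len)
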
@@ -142,12 +142,12 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
-        launch_event: Optional[threading.Event] = None,
+        launch_done: Optional[threading.Event] = None,
     ):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        if launch_event:
-            launch_event.set()
+        if launch_done:
+            launch_done.set()
         next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
         return logits_output, next_token_ids
@@ -96,19 +96,22 @@ class TpModelWorkerClient:
     @torch.no_grad()
     def forward_thread_func_(self):
         while True:
-            model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            model_worker_batch, future_token_ids_ct, compute_info_done = (
+                self.input_queue.get()
+            )
             if not model_worker_batch:
                 break
-            self.launch_event = threading.Event()
-            copy_event = torch.cuda.Event()
+            self.launch_done = threading.Event()
+            copy_done = torch.cuda.Event()

             # Resolve future tokens in the input
             input_ids = model_worker_batch.input_ids
             resolve_future_token_ids(input_ids, self.future_token_ids_map)

             # Run forward
+            compute_info_done.wait()
             logits_output, next_token_ids = self.worker.forward_batch_generation(
-                model_worker_batch, self.launch_event
+                model_worker_batch, self.launch_done
             )

             # Update the future token ids map
@@ -133,15 +136,14 @@ class TpModelWorkerClient:
                     )
                 )
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
-            copy_event.record()
+            copy_done.record()

-            self.output_queue.put((copy_event, logits_output, next_token_ids))
+            self.output_queue.put((copy_done, logits_output, next_token_ids))

     def resolve_batch_result(self, bid: int):
-        copy_event, logits_output, next_token_ids = self.output_queue.get()
-        while not copy_event.query():
-            time.sleep(1e-5)
-        self.launch_event.wait()
+        copy_done, logits_output, next_token_ids = self.output_queue.get()
+        copy_done.synchronize()
+        self.launch_done.wait()

         if logits_output.next_token_logprobs is not None:
             logits_output.next_token_logprobs = (
@@ -162,7 +164,11 @@ class TpModelWorkerClient:
         model_worker_batch.sampling_info = dataclasses.replace(
             model_worker_batch.sampling_info
         )
-        self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
+        compute_info_done = torch.cuda.Event()
+        compute_info_done.record()
+        self.input_queue.put(
+            (model_worker_batch, self.future_token_ids_ct, compute_info_done)
+        )

         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
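
The TpModelWorkerClient changes above follow a producer/consumer pattern: the producer records a CUDA event (compute_info_done) after issuing the GPU work a batch depends on, the forward thread makes its stream wait on that event before launching, and the result consumer blocks on a copy-completion event instead of polling it (a threading.Event, launch_done, additionally signals that the forward launch has happened). A self-contained sketch of that handoff, with illustrative names and a matmul standing in for the real forward pass; it assumes a CUDA device:

import queue
import threading

import torch

task_q = queue.Queue()
result_q = queue.Queue()


def forward_thread():
    # Consumer of tasks: runs the GPU "forward pass" for each queued batch.
    while True:
        batch, ready = task_q.get()
        if batch is None:
            break
        # GPU-side wait: the current stream will not start this work until
        # everything recorded behind `ready` on the producer's stream is done.
        ready.wait()
        out = batch @ batch  # stand-in for the real forward pass
        out_cpu = out.to("cpu", non_blocking=True)
        copy_done = torch.cuda.Event()
        copy_done.record()  # marks completion of the async device-to-host copy
        result_q.put((copy_done, out_cpu))


threading.Thread(target=forward_thread, daemon=True).start()

# Producer: enqueue a batch together with an event marking that the GPU work
# it depends on has already been issued on this stream.
batch = torch.randn(1024, 1024, device="cuda")
ready = torch.cuda.Event()
ready.record()
task_q.put((batch, ready))

# Result consumer: block until the copy has finished instead of spinning on
# copy_done.query() in a sleep loop.
copy_done, out_cpu = result_q.get()
copy_done.synchronize()
print(out_cpu.shape)

task_q.put((None, None))  # shut the worker thread down
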
@@ -38,7 +38,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
             api_key=cls.api_key,
             other_args=(
                 "--max-total-token",
-                "1024",
+                "1536",
                 "--context-len",
                 "8192",
                 "--decode-log-interval",
@@ -29,7 +29,7 @@ class TestSRTEngine(unittest.TestCase):
         sampling_params = {"temperature": 0, "max_new_tokens": 8}
-        engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
+        engine = sgl.Engine(model_path=model_path, random_seed=42)
         out1 = engine.generate(prompt, sampling_params)["text"]
         engine.shutdown()
@@ -51,7 +51,7 @@ class TestSRTEngine(unittest.TestCase):
         sampling_params = {"temperature": 0, "max_new_tokens": 8}
-        engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
+        engine = sgl.Engine(model_path=model_path, random_seed=42)
         engine.generate(prompt, sampling_params)
         engine.generate(prompt, sampling_params)
         engine.shutdown()
@@ -74,7 +74,6 @@ class TestSRTEngine(unittest.TestCase):
         # Create an LLM.
         llm = sgl.Engine(
             model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
-            log_level="error",
         )

         # 1. sync + non streaming
@@ -118,7 +117,9 @@ class TestSRTEngine(unittest.TestCase):
         prompt = "The capital of UK is"
         model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

-        engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
+        engine = sgl.Engine(
+            model_path=model_path, random_seed=42, disable_radix_cache=True
+        )
         sampling_params = {"temperature": 0, "max_new_tokens": 8}

         out1 = engine.generate(prompt, sampling_params)["text"]
@@ -141,9 +142,7 @@ class TestSRTEngine(unittest.TestCase):
         prompt = "Today is a sunny day and I like"
         model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST

-        engine = sgl.Engine(
-            model_path=model_path, is_embedding=True, random_seed=42, log_level="error"
-        )
+        engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42)

         out1 = torch.tensor(engine.encode(prompt)["embedding"])
         engine.shutdown()