Unverified Commit 3cf1473a authored by Lifu Huang, committed by GitHub

Use monotonic clock for interval measurement (#6211)


Signed-off-by: Lifu Huang <lifu.hlf@gmail.com>
parent 27168308
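
As context for the diff below: `time.time()` reads the wall clock, which NTP or an administrator can step forwards or backwards, so intervals derived from it can be distorted or even negative. `time.perf_counter()` and `time.monotonic()` are guaranteed never to go backwards, which makes them the right primitives for measuring elapsed time. A minimal sketch of the pattern this commit adopts (illustrative, not part of the diff):

```python
import time

# time.perf_counter() is monotonic and has the highest available
# resolution, so the computed interval is immune to wall-clock steps
# from NTP or manual clock changes.
start = time.perf_counter()
time.sleep(0.1)  # stand-in for the work being timed
elapsed = time.perf_counter() - start
print(f"elapsed: {elapsed:.4f} s")  # always >= 0
```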

@@ -109,9 +109,9 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=
         for video_path in batch_video_files
     ]
-    start_time = time.time()
+    start_time = time.perf_counter()
     states = video_qa.run_batch(batch_input, max_new_tokens=512, temperature=0.2)
-    total_time = time.time() - start_time
+    total_time = time.perf_counter() - start_time
     average_time = total_time / len(batch_video_files)
     print(
         f"Number of videos in batch: {len(batch_video_files)}. Average processing time per video: {average_time:.2f} seconds. Total time for this batch: {total_time:.2f} seconds"

@@ -240,11 +240,11 @@ if __name__ == "__main__":
         for f in os.listdir(root)
         if f.endswith((".mp4", ".avi", ".mov"))
     ]  # Add more extensions if needed
-    start_time = time.time()  # Start time for processing a single video
+    start_time = time.perf_counter()  # Start time for processing a single video
     for cur_video in video_files[:1]:
         print(cur_video)
         single(cur_video, num_frames)
-    end_time = time.time()  # End time for processing a single video
+    end_time = time.perf_counter()  # End time for processing a single video
     total_time = end_time - start_time
     average_time = total_time / len(
         video_files

@@ -89,9 +89,9 @@ def start_server(args, timeout=60):
     process = subprocess.Popen(command, stdout=None, stderr=None)
-    start_time = time.time()
+    start_time = time.perf_counter()
     with requests.Session() as session:
-        while time.time() - start_time < timeout:
+        while time.perf_counter() - start_time < timeout:
             try:
                 # Check the /docs endpoint which FastAPI provides by default
                 response = session.get(

@@ -150,7 +150,7 @@ def video_stream_request_test(client, video_path):
 def image_speed_test(client):
     print("----------------------Image Speed Test----------------------")
-    start_time = time.time()
+    start_time = time.perf_counter()
     request = client.chat.completions.create(
         model="default",
         messages=[

@@ -173,7 +173,7 @@ def image_speed_test(client):
         temperature=0,
         max_tokens=1024,
     )
-    end_time = time.time()
+    end_time = time.perf_counter()
     response = request.choices[0].message.content
     print(response)
     print("-" * 30)

@@ -184,14 +184,14 @@ def video_speed_test(client, video_path):
     print("------------------------Video Speed Test------------------------")
     messages = prepare_video_messages(video_path)
-    start_time = time.time()
+    start_time = time.perf_counter()
     video_request = client.chat.completions.create(
         model="default",
         messages=messages,
         temperature=0,
         max_tokens=1024,
     )
-    end_time = time.time()
+    end_time = time.perf_counter()
     video_response = video_request.choices[0].message.content
     print(video_response)
     print("-" * 30)

@@ -373,10 +373,10 @@ def latency_test_run_once(
     # Prefill
     synchronize(device)
-    tic = time.time()
+    tic = time.perf_counter()
     next_token_ids, _, batch = extend(reqs, model_runner)
     synchronize(device)
-    prefill_latency = time.time() - tic
+    prefill_latency = time.perf_counter() - tic
     tot_latency += prefill_latency
     throughput = input_len * batch_size / prefill_latency
     rank_print(

@@ -389,10 +389,10 @@ def latency_test_run_once(
     decode_latencies = []
     for i in range(output_len - 1):
         synchronize(device)
-        tic = time.time()
+        tic = time.perf_counter()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         synchronize(device)
-        latency = time.time() - tic
+        latency = time.perf_counter() - tic
         tot_latency += latency
         throughput = batch_size / latency
         decode_latencies.append(latency)

@@ -92,8 +92,8 @@ def launch_server_process(server_args: ServerArgs):
     base_url = f"http://{server_args.host}:{server_args.port}"
     timeout = 600
-    start_time = time.time()
-    while time.time() - start_time < timeout:
+    start_time = time.perf_counter()
+    while time.perf_counter() - start_time < timeout:
         try:
             headers = {
                 "Content-Type": "application/json; charset=utf-8",

@@ -141,7 +141,7 @@ def run_one_case(
     else:
         json_schema = None
-    tic = time.time()
+    tic = time.perf_counter()
     response = requests.post(
         url + "/generate",
         json={

@@ -175,9 +175,9 @@ def run_one_case(
             or data["meta_info"]["finish_reason"]["type"] == "length"
         )
         if data["meta_info"]["completion_tokens"] == 1:
-            ttft = time.time() - tic
-    latency = time.time() - tic
+            ttft = time.perf_counter() - tic
+    latency = time.perf_counter() - tic
     input_throughput = batch_size * input_len / ttft
     output_throughput = batch_size * output_len / (latency - ttft)
     overall_throughput = batch_size * (input_len + output_len) / latency
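
The hunk above splits one request's latency into time-to-first-token (TTFT) and decode time, attributing prefill cost to the former and decode cost to the rest. A hedged sketch of that decomposition; `stream_chunks` is a hypothetical iterable standing in for the streamed /generate response:

```python
import time

def measure_ttft_and_latency(stream_chunks):
    """Sketch of the TTFT/decode split; `stream_chunks` is hypothetical."""
    tic = time.perf_counter()
    ttft = None
    for _chunk in stream_chunks:
        if ttft is None:  # first streamed token observed
            ttft = time.perf_counter() - tic
    latency = time.perf_counter() - tic
    # Prefill cost is attributed to ttft, decode cost to (latency - ttft).
    return ttft, latency
```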

@@ -82,8 +82,8 @@ def launch_server_process_and_send_one_request(
         base_url = f"http://{server_args.host}:{server_args.port}"
         timeout = compile_args.timeout
-        start_time = time.time()
-        while time.time() - start_time < timeout:
+        start_time = time.perf_counter()
+        while time.perf_counter() - start_time < timeout:
             try:
                 headers = {
                     "Content-Type": "application/json; charset=utf-8",

@@ -112,9 +112,9 @@ def launch_server_process_and_send_one_request(
             raise RuntimeError(f"Sync request failed: {error}")
     # Other nodes should wait for the exit signal from Rank-0 node.
     else:
-        start_time_waiting = time.time()
+        start_time_waiting = time.perf_counter()
         while proc.is_alive():
-            if time.time() - start_time_waiting < timeout:
+            if time.perf_counter() - start_time_waiting < timeout:
                 time.sleep(10)
             else:
                 raise TimeoutError("Waiting for main node timeout!")

@@ -127,14 +127,14 @@ class StatelessProcessGroup:
         key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
         self.store.set(key, pickle.dumps(obj))
         self.send_dst_counter[dst] += 1
-        self.entries.append((key, time.time()))
+        self.entries.append((key, time.perf_counter()))

     def expire_data(self):
         """Expire data that is older than `data_expiration_seconds` seconds."""
         while self.entries:
             # check the oldest entry
             key, timestamp = self.entries[0]
-            if time.time() - timestamp > self.data_expiration_seconds:
+            if time.perf_counter() - timestamp > self.data_expiration_seconds:
                 self.store.delete_key(key)
                 self.entries.popleft()
             else:

@@ -158,7 +158,7 @@ class StatelessProcessGroup:
             key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}"
             self.store.set(key, pickle.dumps(obj))
             self.broadcast_send_counter += 1
-            self.entries.append((key, time.time()))
+            self.entries.append((key, time.perf_counter()))
             return obj
         else:
             key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}"
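
A note on the hunks above: `time.perf_counter()` has an undefined epoch and its values are only meaningful within one process, but each entry in `self.entries` is both stamped and later compared in the same process, so the substitution is safe. A condensed sketch of the expiry pattern, assuming a plain deque and a hypothetical 10-second lifetime:

```python
import time
from collections import deque

DATA_EXPIRATION_SECONDS = 10.0  # assumed lifetime for this sketch
entries = deque()  # (key, timestamp) pairs, oldest first

def record(key: str) -> None:
    entries.append((key, time.perf_counter()))

def expire_data() -> None:
    # Entries were appended in time order, so only the head can be stale.
    while entries:
        key, timestamp = entries[0]
        if time.perf_counter() - timestamp > DATA_EXPIRATION_SECONDS:
            entries.popleft()  # the real code also deletes `key` from the store
        else:
            break
```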

@@ -182,9 +182,9 @@ async def health_generate(request: Request) -> Response:
         async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
             break

-    tic = time.time()
+    tic = time.perf_counter()
     task = asyncio.create_task(gen())
-    while time.time() < tic + HEALTH_CHECK_TIMEOUT:
+    while time.perf_counter() < tic + HEALTH_CHECK_TIMEOUT:
         await asyncio.sleep(1)
         if _global_state.tokenizer_manager.last_receive_tstamp > tic:
             task.cancel()

@@ -24,10 +24,10 @@ def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process:
     base_url = server_args.url()
     timeout = 300.0  # Increased timeout to 5 minutes for downloading large models
-    start_time = time.time()
+    start_time = time.perf_counter()
     with requests.Session() as session:
-        while time.time() - start_time < timeout:
+        while time.perf_counter() - start_time < timeout:
             try:
                 headers = {
                     "Content-Type": "application/json; charset=utf-8",

@@ -348,8 +348,8 @@ class Scheduler(
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
         self.num_prefill_tokens = 0
-        self.last_decode_stats_tic = time.time()
-        self.last_prefill_stats_tic = time.time()
+        self.last_decode_stats_tic = time.perf_counter()
+        self.last_prefill_stats_tic = time.perf_counter()
         self.return_health_check_ct = 0
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":

@@ -1032,13 +1032,13 @@ class Scheduler(
             add_to_grammar_queue = True

         if add_to_grammar_queue:
-            req.queue_time_start = time.time()
+            req.queue_time_start = time.perf_counter()
             self.grammar_queue.append(req)
         else:
             self._add_request_to_queue(req)

     def _add_request_to_queue(self, req: Req):
-        req.queue_time_start = time.time()
+        req.queue_time_start = time.perf_counter()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.disagg_prefill_bootstrap_queue.add(req)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:

@@ -1085,7 +1085,7 @@ class Scheduler(
             req.finished_reason = FINISH_ABORT(
                 error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
             )
-            req.queue_time_start = time.time()
+            req.queue_time_start = time.perf_counter()
             self.waiting_queue.append(req)
             return

@@ -1109,8 +1109,8 @@ class Scheduler(
         can_run_list: List[Req],
         running_bs: int,
     ):
-        gap_latency = time.time() - self.last_prefill_stats_tic
-        self.last_prefill_stats_tic = time.time()
+        gap_latency = time.perf_counter() - self.last_prefill_stats_tic
+        self.last_prefill_stats_tic = time.perf_counter()
         self.last_input_throughput = self.num_prefill_tokens / gap_latency
         self.num_prefill_tokens = 0

@@ -1160,8 +1160,8 @@ class Scheduler(
     ):
         batch = running_batch or self.running_batch
-        gap_latency = time.time() - self.last_decode_stats_tic
-        self.last_decode_stats_tic = time.time()
+        gap_latency = time.perf_counter() - self.last_decode_stats_tic
+        self.last_decode_stats_tic = time.perf_counter()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
         num_running_reqs = len(batch.reqs)

@@ -1245,7 +1245,7 @@ class Scheduler(
         if (
             self.enable_metrics
             and self.attn_tp_rank == 0
-            and time.time() > self.metrics_collector.last_log_time + 30
+            and time.perf_counter() > self.metrics_collector.last_log_time + 30
         ):
             # During idle time, also collect metrics every 30 seconds.
             num_used = self.max_total_num_tokens - (

@@ -1410,7 +1410,7 @@ class Scheduler(
         if self.enable_metrics:
             # only record queue time when enable_metrics is True to avoid overhead
             for req in can_run_list:
-                req.queue_time_end = time.time()
+                req.queue_time_end = time.perf_counter()

         self.waiting_queue = [
             x for x in self.waiting_queue if x not in set(can_run_list)

@@ -1783,10 +1783,10 @@ class Scheduler(
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
         self.watchdog_last_forward_ct = 0
-        self.watchdog_last_time = time.time()
+        self.watchdog_last_time = time.perf_counter()

         while True:
-            current = time.time()
+            current = time.perf_counter()
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
                     if current > self.watchdog_last_time + self.watchdog_timeout:

@@ -335,13 +335,13 @@ class HiRadixCache(RadixCache):
         return value, last_node

     def _match_prefix_helper(self, node: TreeNode, key: List):
-        node.last_access_time = time.time()
+        node.last_access_time = time.monotonic()
         child_key = self.get_child_key_fn(key)

         value = []
         while len(key) > 0 and child_key in node.children.keys():
             child = node.children[child_key]
-            child.last_access_time = time.time()
+            child.last_access_time = time.monotonic()
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)

@@ -386,7 +386,7 @@ class HiRadixCache(RadixCache):
         return new_node

     def _insert_helper(self, node: TreeNode, key: List, value):
-        node.last_access_time = time.time()
+        node.last_access_time = time.monotonic()
         if len(key) == 0:
             return 0

@@ -395,7 +395,7 @@ class HiRadixCache(RadixCache):
         while len(key) > 0 and child_key in node.children.keys():
             node = node.children[child_key]
-            node.last_access_time = time.time()
+            node.last_access_time = time.monotonic()
             prefix_len = self.key_match_fn(node.key, key)
             if prefix_len == len(node.key):

@@ -45,7 +45,7 @@ class TreeNode:
         self.key = None
         self.value = None
         self.lock_ref = 0
-        self.last_access_time = time.time()
+        self.last_access_time = time.monotonic()

         self.hit_count = 0
         # indicating the node is loading KV cache from host

@@ -322,14 +322,14 @@ class RadixCache(BasePrefixCache):
     ##### Internal Helper Functions #####

     def _match_prefix_helper(self, node: TreeNode, key: List):
-        node.last_access_time = time.time()
+        node.last_access_time = time.monotonic()
         child_key = self.get_child_key_fn(key)

         value = []
         while len(key) > 0 and child_key in node.children.keys():
             child = node.children[child_key]
-            child.last_access_time = time.time()
+            child.last_access_time = time.monotonic()
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)

@@ -361,7 +361,7 @@ class RadixCache(BasePrefixCache):
         return new_node

     def _insert_helper(self, node: TreeNode, key: List, value):
-        node.last_access_time = time.time()
+        node.last_access_time = time.monotonic()
         if len(key) == 0:
             return 0

@@ -370,7 +370,7 @@ class RadixCache(BasePrefixCache):
         total_prefix_length = 0
         while len(key) > 0 and child_key in node.children.keys():
             node = node.children[child_key]
-            node.last_access_time = time.time()
+            node.last_access_time = time.monotonic()
             prefix_len = self.key_match_fn(node.key, key)
             total_prefix_length += prefix_len
             key = key[prefix_len:]
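
In the cache hunks above, `last_access_time` only has to order nodes by recency for LRU-style eviction, so `time.monotonic()` suffices: it never goes backwards, so a just-touched node can never look older than a stale one. A minimal sketch of that eviction ordering, with a hypothetical `Node` class rather than the real `TreeNode` API:

```python
import time

class Node:
    """Hypothetical stand-in for a cache node with a recency stamp."""

    def __init__(self, key: str):
        self.key = key
        self.last_access_time = time.monotonic()

    def touch(self) -> None:
        # Monotonic readings never decrease, so recency ordering is reliable.
        self.last_access_time = time.monotonic()

def pick_eviction_victim(nodes: list[Node]) -> Node:
    # Evict the least recently used node (smallest monotonic timestamp).
    return min(nodes, key=lambda n: n.last_access_time)
```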

@@ -154,7 +154,7 @@ class SchedulerMetricsCollector:
         from prometheus_client import Counter, Gauge

         self.labels = labels
-        self.last_log_time = time.time()
+        self.last_log_time = time.perf_counter()

         self.num_running_reqs = Gauge(
             name="sglang:num_running_reqs",

@@ -294,7 +294,7 @@ class SchedulerMetricsCollector:
             self.num_decode_transfer_queue_reqs, stats.num_decode_transfer_queue_reqs
         )
-        self.last_log_time = time.time()
+        self.last_log_time = time.perf_counter()

 class TokenizerMetricsCollector:

@@ -1019,7 +1019,7 @@ class ModelRunner:
         if self.server_args.disable_cuda_graph:
             return

-        tic = time.time()
+        tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"

@@ -1027,7 +1027,7 @@ class ModelRunner:
         self.cuda_graph_runner = CudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
+            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
             f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
         )

@@ -228,7 +228,7 @@ class EAGLEWorker(TpModelWorker):
             return

         # Capture draft
-        tic = time.time()
+        tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"

@@ -236,7 +236,7 @@ class EAGLEWorker(TpModelWorker):
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )

         # Capture extend

@@ -246,7 +246,7 @@ def mark_start(name, interval=0.1, color=0, indent=0):
     torch.cuda.synchronize()
     if time_infos.get(name, None) is None:
         time_infos[name] = TimeInfo(name, interval, color, indent)
-    time_infos[name].acc_time -= time.time()
+    time_infos[name].acc_time -= time.perf_counter()

 def mark_end(name):

@@ -254,7 +254,7 @@ def mark_end(name):
     if not show_time_cost:
         return
     torch.cuda.synchronize()
-    time_infos[name].acc_time += time.time()
+    time_infos[name].acc_time += time.perf_counter()
     if time_infos[name].check():
         time_infos[name].pretty_print()

@@ -264,11 +264,11 @@ def calculate_time(show=False, min_cost_ms=0.0):
     def inner_func(*args, **kwargs):
         torch.cuda.synchronize()
         if show:
-            start_time = time.time()
+            start_time = time.perf_counter()
         result = func(*args, **kwargs)
         torch.cuda.synchronize()
         if show:
-            cost_time = (time.time() - start_time) * 1000
+            cost_time = (time.perf_counter() - start_time) * 1000
             if cost_time > min_cost_ms:
                 print(f"Function {func.__name__} took {cost_time} ms to run.")
         return result

@@ -526,9 +526,9 @@ def popen_launch_pd_server(
     else:
         process = subprocess.Popen(command, stdout=None, stderr=None, env=env)

-    start_time = time.time()
+    start_time = time.perf_counter()
     with requests.Session() as session:
-        while time.time() - start_time < timeout:
+        while time.perf_counter() - start_time < timeout:
             try:
                 headers = {
                     "Content-Type": "application/json; charset=utf-8",

@@ -436,7 +436,7 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
         base_url: The base URL of the server
         timeout: Maximum time to wait in seconds. None means wait forever.
     """
-    start_time = time.time()
+    start_time = time.perf_counter()
     while True:
         try:
             response = requests.get(

@@ -455,7 +455,7 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
                 )
                 break

-            if timeout and time.time() - start_time > timeout:
+            if timeout and time.perf_counter() - start_time > timeout:
                 raise TimeoutError("Server did not become ready within timeout period")
         except requests.exceptions.RequestException:
             time.sleep(1)

@@ -91,10 +91,10 @@ def launch_server_process(
 def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
     """Wait for server to be healthy by checking /health endpoint."""
-    start_time = time.time()
+    start_time = time.perf_counter()
     url = f"http://{host}:{port}/health"

-    while time.time() - start_time < timeout:
+    while time.perf_counter() - start_time < timeout:
         try:
             response = requests.get(url, timeout=5)
             if response.status_code == 200:

@@ -97,7 +97,7 @@ class TestDisaggregationMooncake(CustomTestCase):
     @classmethod
     def wait_server_ready(cls, url, timeout=60):
-        start_time = time.time()
+        start_time = time.perf_counter()
         while True:
             try:
                 response = requests.get(url)

@@ -107,7 +107,7 @@ class TestDisaggregationMooncake(CustomTestCase):
             except Exception:
                 pass

-            if time.time() - start_time > timeout:
+            if time.perf_counter() - start_time > timeout:
                 raise RuntimeError(f"Server {url} failed to start in {timeout}s")
             time.sleep(1)
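
Several hunks in this commit share the same readiness-polling shape: stamp a start time, poll until a deadline, sleep between attempts. A condensed sketch of that pattern, assuming a /health-style endpoint that returns HTTP 200 when the server is ready:

```python
import time

import requests

def wait_until_healthy(url: str, timeout: float = 60.0) -> bool:
    """Poll `url` until it returns HTTP 200 or `timeout` seconds elapse."""
    start_time = time.perf_counter()
    # perf_counter-based deadlines are immune to wall-clock steps, so the
    # loop can neither exit early nor hang if the system clock is adjusted.
    while time.perf_counter() - start_time < timeout:
        try:
            if requests.get(url, timeout=5).status_code == 200:
                return True
        except requests.exceptions.RequestException:
            pass  # server not accepting connections yet
        time.sleep(1)
    return False
```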