"torchvision/vscode:/vscode.git/clone" did not exist on "5f0532daea1f4bdf5e2c511361b484cdb002b50b"
Unverified Commit f87a6ab3 authored by Yuhong Guo's avatar Yuhong Guo Committed by GitHub
Browse files

Resolves the `404 Not Found` error when running `compile_deep_gemm.py` in multi-node setups (#5720)

parent eebfdb94
...@@ -88,8 +88,36 @@ def launch_server_process_and_send_one_request( ...@@ -88,8 +88,36 @@ def launch_server_process_and_send_one_request(
headers = { headers = {
"Content-Type": "application/json; charset=utf-8", "Content-Type": "application/json; charset=utf-8",
} }
response = requests.get(f"{base_url}/v1/models", headers=headers) if server_args.node_rank == 0:
response = requests.get(f"{base_url}/v1/models", headers=headers)
else:
# This http api is created by launch_dummy_health_check_server for none-rank0 node.
response = requests.get(f"{base_url}/health", headers=headers)
if response.status_code == 200: if response.status_code == 200:
# Rank-0 node send a request to sync with other node and then return.
if server_args.node_rank == 0:
response = requests.post(
f"{base_url}/generate",
json={
"input_ids": [0, 1, 2, 3],
"sampling_params": {
"max_new_tokens": 8,
"temperature": 0,
},
},
timeout=600,
)
if response.status_code != 200:
error = response.json()
raise RuntimeError(f"Sync request failed: {error}")
# Other nodes should wait for the exit signal from Rank-0 node.
else:
start_time_waiting = time.time()
while proc.is_alive():
if time.time() - start_time_waiting < timeout:
time.sleep(10)
else:
raise TimeoutError("Waiting for main node timeout!")
return proc return proc
except requests.RequestException: except requests.RequestException:
pass pass
...@@ -122,10 +150,19 @@ def run_compile(server_args: ServerArgs, compile_args: CompileArgs): ...@@ -122,10 +150,19 @@ def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
proc = launch_server_process_and_send_one_request(server_args, compile_args) proc = launch_server_process_and_send_one_request(server_args, compile_args)
kill_process_tree(proc.pid)
print("\nDeepGEMM Kernels compilation finished successfully.") print("\nDeepGEMM Kernels compilation finished successfully.")
# Sleep for safety
time.sleep(10)
if proc.is_alive():
# This is the rank0 node.
kill_process_tree(proc.pid)
else:
try:
kill_process_tree(proc.pid)
except Exception:
pass
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment