"src/array/vscode:/vscode.git/clone" did not exist on "f9ad1c80da2fba12d9745acb3c182de3d3642f14"
Unverified commit 6b7038ba, authored by fzyzcjy and committed by GitHub
Browse files

Speedup warmup when DP > 1 (#4695)

parent 57eec0bf
...@@ -730,9 +730,9 @@ def _wait_and_warmup( ...@@ -730,9 +730,9 @@ def _wait_and_warmup(
}, },
} }
if server_args.skip_tokenizer_init: if server_args.skip_tokenizer_init:
json_data["input_ids"] = [10, 11, 12] json_data["input_ids"] = [[10, 11, 12] for _ in range(server_args.dp_size)]
else: else:
json_data["text"] = "The capital city of France is" json_data["text"] = ["The capital city of France is"] * server_args.dp_size
# Debug dumping # Debug dumping
if server_args.debug_tensor_dump_input_file: if server_args.debug_tensor_dump_input_file:
...@@ -743,14 +743,13 @@ def _wait_and_warmup( ...@@ -743,14 +743,13 @@ def _wait_and_warmup(
json_data["sampling_params"]["max_new_tokens"] = 0 json_data["sampling_params"]["max_new_tokens"] = 0
try: try:
for i in range(server_args.dp_size): res = requests.post(
res = requests.post( url + request_name,
url + request_name, json=json_data,
json=json_data, headers=headers,
headers=headers, timeout=600,
timeout=600, )
) assert res.status_code == 200, f"{res}"
assert res.status_code == 200, f"{res}"
except Exception: except Exception:
last_traceback = get_exception_traceback() last_traceback = get_exception_traceback()
if pipe_finish_writer is not None: if pipe_finish_writer is not None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment