Unverified Commit 18108abe authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

[Minor] Fix code style (#2311)

parent c54bda30
......@@ -25,7 +25,6 @@ import uuid
from typing import Dict, List, Optional, Tuple, Union
import fastapi
import torch
import uvloop
import zmq
import zmq.asyncio
......@@ -337,6 +336,12 @@ class TokenizerManager:
rids.append(tmp_obj.rid)
else:
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
if batch_size > 128:
logger.warning(
"Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
"The performance might be better if you just duplicate the requests n times or use "
"many threads to send them one by one with parallel sampling (n > 1)."
)
# Tokenize all requests
objs = [obj[i] for i in range(batch_size)]
......@@ -494,9 +499,7 @@ class TokenizerManager:
result = await self.parameter_update_result
return result.success, result.message
else:
logger.error(
f"Another parameter update is in progress in tokenizer manager"
)
logger.error("Another parameter update is in progress in tokenizer manager")
return (
False,
"Another parameter update is in progress. Please try again later.",
......@@ -597,7 +600,68 @@ class TokenizerManager:
InitWeightsUpdateGroupReqOutput,
] = await self.recv_from_detokenizer.recv_pyobj()
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
continue
recv_obj.meta_info[i]["id"] = rid
if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
}
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
"embedding": recv_obj.embeddings[i],
"meta_info": recv_obj.meta_info[i],
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reason[i] is not None
state.event.set()
if self.enable_metrics:
completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
if state.first_token_time is None:
state.first_token_time = time.time()
self.metrics_collector.observe_time_to_first_token(
state.first_token_time - state.created_time
)
else:
if completion_tokens >= 2:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.first_token_time)
/ (completion_tokens - 1)
)
if state.finished:
self.metrics_collector.inc_prompt_tokens(
recv_obj.meta_info[i]["prompt_tokens"]
)
self.metrics_collector.inc_generation_tokens(
completion_tokens
)
self.metrics_collector.observe_e2e_request_latency(
time.time() - state.created_time
)
if completion_tokens >= 1:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.created_time)
/ completion_tokens
)
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id
)
elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
if self.server_args.dp_size == 1:
self.model_update_result.set_result(recv_obj)
else: # self.server_args.dp_size > 1
......@@ -605,13 +669,16 @@ class TokenizerManager:
# set future if the all results are recevied
if len(self.model_update_tmp) == self.server_args.dp_size:
self.model_update_result.set_result(self.model_update_tmp)
continue
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
self.init_weights_update_group_result.set_result(recv_obj)
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for update weights from distributed"
self.parameter_update_result.set_result(recv_obj)
continue
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
if self.server_args.dp_size == 1:
self.get_weights_by_name_result.set_result(recv_obj)
......@@ -621,76 +688,8 @@ class TokenizerManager:
self.get_weights_by_name_result.set_result(
self.get_weights_by_name_tmp
)
continue
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
self.init_weights_update_group_result.set_result(recv_obj)
continue
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id
)
continue
assert isinstance(
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
), f"Unexpected obj received: {type(recv_obj)}"
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
continue
recv_obj.meta_info[i]["id"] = rid
if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
}
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
"embedding": recv_obj.embeddings[i],
"meta_info": recv_obj.meta_info[i],
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reason[i] is not None
state.event.set()
if self.enable_metrics:
completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
if state.first_token_time is None:
state.first_token_time = time.time()
self.metrics_collector.observe_time_to_first_token(
state.first_token_time - state.created_time
)
else:
if completion_tokens >= 2:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.first_token_time)
/ (completion_tokens - 1)
)
if state.finished:
self.metrics_collector.inc_prompt_tokens(
recv_obj.meta_info[i]["prompt_tokens"]
)
self.metrics_collector.inc_generation_tokens(completion_tokens)
self.metrics_collector.observe_e2e_request_latency(
time.time() - state.created_time
)
if completion_tokens >= 1:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.created_time) / completion_tokens
)
else:
raise ValueError(f"Invalid object: {recv_obj=}")
def convert_logprob_style(
self,
......
......@@ -218,16 +218,6 @@ class ModelRunner:
)
self.tp_group = get_tp_group()
# Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
# so we disable padding in cuda graph.
if self.device == "cuda" and not all(
in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
):
self.server_args.disable_cuda_graph_padding = True
logger.info(
"Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
)
# Check memory for tensor parallelism
if self.tp_size > 1:
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
......
......@@ -82,7 +82,6 @@ from sglang.srt.utils import (
assert_pkg_version,
configure_logger,
delete_directory,
init_custom_process_group,
is_port_available,
kill_process_tree,
maybe_set_triton_cache_manager,
......@@ -154,13 +153,11 @@ async def get_model_info():
@app.get("/get_server_info")
async def get_server_info():
try:
return await _get_server_info()
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
return {
**dataclasses.asdict(tokenizer_manager.server_args), # server args
**scheduler_info,
"version": __version__,
}
@app.post("/flush_cache")
......@@ -567,14 +564,6 @@ def launch_server(
t.join()
async def _get_server_info():
return {
**dataclasses.asdict(tokenizer_manager.server_args), # server args
**scheduler_info,
"version": __version__,
}
def _set_envs_and_config(server_args: ServerArgs):
# Set global environments
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
......@@ -687,160 +676,6 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
delete_directory(server_args.model_path)
class Runtime:
"""
A wrapper for the server.
This is used for launching the server in a python program without
using the commond line interface.
"""
def __init__(
self,
log_level: str = "error",
*args,
**kwargs,
):
"""See the arguments in server_args.py::ServerArgs"""
self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
# before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
atexit.register(self.shutdown)
# Pre-allocate ports
for port in range(10000, 40000):
if is_port_available(port):
break
port += 1
self.server_args.port = port
self.url = self.server_args.url()
self.generate_url = self.url + "/generate"
# NOTE: We store pid instead of proc to fix some issues during __delete__
self.pid = None
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
proc = mp.Process(
target=launch_server,
args=(self.server_args, pipe_writer),
)
proc.start()
pipe_writer.close()
self.pid = proc.pid
try:
init_state = pipe_reader.recv()
except EOFError:
init_state = ""
if init_state != "ready":
self.shutdown()
raise RuntimeError(
"Initialization failed. Please see the error messages above."
)
self.endpoint = RuntimeEndpoint(self.url)
def shutdown(self):
if self.pid is not None:
kill_process_tree(self.pid)
self.pid = None
def cache_prefix(self, prefix: str):
self.endpoint.cache_prefix(prefix)
def get_tokenizer(self):
return get_tokenizer(
self.server_args.tokenizer_path,
tokenizer_mode=self.server_args.tokenizer_mode,
trust_remote_code=self.server_args.trust_remote_code,
)
async def async_generate(
self,
prompt: str,
sampling_params: Optional[Dict] = None,
):
if self.server_args.skip_tokenizer_init:
json_data = {
"input_ids": prompt,
"sampling_params": sampling_params,
"stream": True,
}
else:
json_data = {
"text": prompt,
"sampling_params": sampling_params,
"stream": True,
}
pos = 0
timeout = aiohttp.ClientTimeout(total=3 * 3600)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.post(self.generate_url, json=json_data) as response:
async for chunk, _ in response.content.iter_chunks():
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]\n\n":
break
data = json.loads(chunk[5:].strip("\n"))
if "text" in data:
cur = data["text"][pos:]
if cur:
yield cur
pos += len(cur)
else:
yield data
add_request = async_generate
def generate(
self,
prompt: Union[str, List[str]],
sampling_params: Optional[Dict] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
):
json_data = {
"text": prompt,
"sampling_params": sampling_params,
"return_logprob": return_logprob,
"logprob_start_len": logprob_start_len,
"top_logprobs_num": top_logprobs_num,
"lora_path": lora_path,
}
assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
response = requests.post(
self.url + "/generate",
json=json_data,
)
return json.dumps(response.json())
def encode(
self,
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
):
json_data = {"text": prompt}
response = requests.post(self.url + "/encode", json=json_data)
return json.dumps(response.json())
async def get_server_info(self):
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.url}/get_server_info") as response:
if response.status == 200:
return await response.json()
else:
error_data = await response.json()
raise RuntimeError(
f"Failed to get server info. {error_data['error']['message']}"
)
def __del__(self):
self.shutdown()
STREAM_END_SYMBOL = b"data: [DONE]"
STREAM_CHUNK_START_SYMBOL = b"data:"
......@@ -854,6 +689,8 @@ class Engine:
"""
def __init__(self, log_level: str = "error", *args, **kwargs):
"""See the arguments in server_args.py::ServerArgs"""
# before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
atexit.register(self.shutdown)
......@@ -986,8 +823,12 @@ class Engine:
def stop_profile(self):
tokenizer_manager.stop_profile()
async def get_server_info(self):
return await _get_server_info()
def get_server_info(self):
return {
**dataclasses.asdict(tokenizer_manager.server_args), # server args
**scheduler_info,
"version": __version__,
}
def init_weights_update_group(
self,
......@@ -1037,3 +878,160 @@ class Engine:
loop = asyncio.get_event_loop()
return loop.run_until_complete(_get_weights())
class Runtime:
"""
A wrapper for the HTTP server.
This is used for launching the server in a python program without
using the commond line interface.
It is mainly used for the frontend language.
You should use the Engine class if you want to do normal offline processing.
"""
def __init__(
self,
log_level: str = "error",
*args,
**kwargs,
):
"""See the arguments in server_args.py::ServerArgs"""
self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
# before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
atexit.register(self.shutdown)
# Pre-allocate ports
for port in range(10000, 40000):
if is_port_available(port):
break
port += 1
self.server_args.port = port
self.url = self.server_args.url()
self.generate_url = self.url + "/generate"
# NOTE: We store pid instead of proc to fix some issues during __delete__
self.pid = None
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
proc = mp.Process(
target=launch_server,
args=(self.server_args, pipe_writer),
)
proc.start()
pipe_writer.close()
self.pid = proc.pid
try:
init_state = pipe_reader.recv()
except EOFError:
init_state = ""
if init_state != "ready":
self.shutdown()
raise RuntimeError(
"Initialization failed. Please see the error messages above."
)
self.endpoint = RuntimeEndpoint(self.url)
def shutdown(self):
if self.pid is not None:
kill_process_tree(self.pid)
self.pid = None
def cache_prefix(self, prefix: str):
self.endpoint.cache_prefix(prefix)
def get_tokenizer(self):
return get_tokenizer(
self.server_args.tokenizer_path,
tokenizer_mode=self.server_args.tokenizer_mode,
trust_remote_code=self.server_args.trust_remote_code,
)
async def async_generate(
self,
prompt: str,
sampling_params: Optional[Dict] = None,
):
if self.server_args.skip_tokenizer_init:
json_data = {
"input_ids": prompt,
"sampling_params": sampling_params,
"stream": True,
}
else:
json_data = {
"text": prompt,
"sampling_params": sampling_params,
"stream": True,
}
pos = 0
timeout = aiohttp.ClientTimeout(total=3 * 3600)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.post(self.generate_url, json=json_data) as response:
async for chunk, _ in response.content.iter_chunks():
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]\n\n":
break
data = json.loads(chunk[5:].strip("\n"))
if "text" in data:
cur = data["text"][pos:]
if cur:
yield cur
pos += len(cur)
else:
yield data
add_request = async_generate
def generate(
self,
prompt: Union[str, List[str]],
sampling_params: Optional[Dict] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
):
json_data = {
"text": prompt,
"sampling_params": sampling_params,
"return_logprob": return_logprob,
"logprob_start_len": logprob_start_len,
"top_logprobs_num": top_logprobs_num,
"lora_path": lora_path,
}
assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
response = requests.post(
self.url + "/generate",
json=json_data,
)
return json.dumps(response.json())
def encode(
self,
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
):
json_data = {"text": prompt}
response = requests.post(self.url + "/encode", json=json_data)
return json.dumps(response.json())
async def get_server_info(self):
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.url}/get_server_info") as response:
if response.status == 200:
return await response.json()
else:
error_data = await response.json()
raise RuntimeError(
f"Failed to get server info. {error_data['error']['message']}"
)
def __del__(self):
self.shutdown()
......@@ -67,7 +67,7 @@ class TestGetWeightsByName(unittest.TestCase):
terminate_process(self.process)
def assert_tie_word_embeddings(self, truncate_size):
print(f"assert_tie_word_embeddings")
print("assert_tie_word_embeddings")
if self.backend == "Engine":
backend_ret = _process_return(
self.engine.get_weights_by_name("lm_head.weight", truncate_size)
......@@ -79,7 +79,7 @@ class TestGetWeightsByName(unittest.TestCase):
json={"name": "lm_head.weight", "truncate_size": truncate_size},
).json()
)
print(f"assert_tie_word_embeddings of hf and backend")
print("assert_tie_word_embeddings of hf and backend")
assert np.allclose(
self.hf_model.get_parameter("model.embed_tokens.weight")
.cpu()
......
......@@ -127,7 +127,7 @@ def init_process_hf(
hf_instruct_params = []
hf_base_params = []
print(f"get parameter in hf instruct model and base model")
print("get parameter in hf instruct model and base model")
for parameter_name in checking_parameters:
hf_instruct_params.append(
hf_instruct_model.get_parameter(parameter_name)[:truncate_size]
......@@ -186,7 +186,6 @@ def init_process_hf(
param_queue.put(("broadcast_time", broadcast_time))
# Delete the huggingface models to free up memory.
del hf_instruct_model
del hf_base_model
gc.collect()
......@@ -238,7 +237,6 @@ def init_process_sgl(
print(f"rank {rank} init server on url: {url}")
# Get weights of instruct model, i.e. pre-training weights.
instruct_params = []
for parameter_name in checking_parameters:
instruct_params.append(
......@@ -253,7 +251,6 @@ def init_process_sgl(
param_queue.put((f"sgl_dp_{rank}_instruct_params", instruct_params))
# Init weight update group with the training engine.
if backend == "Engine":
engine.init_weights_update_group(
master_address="localhost",
......@@ -282,7 +279,6 @@ def init_process_sgl(
# The last parameter is lm_head.weight, which is tied
# with embed_tokens.weight. Actually, we only need
# to update embed_tokens.weight once.
tie_word_embeddings = (
True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
)
......@@ -291,7 +287,6 @@ def init_process_sgl(
update_parameters.remove("lm_head.weight")
# Get weights from the training engine and update the inference engine.
for parameter_name in update_parameters:
if backend == "Engine":
engine.update_weights_from_distributed(
......@@ -312,7 +307,6 @@ def init_process_sgl(
time_end_update = time.time()
# Measure the latency of broadcast/weights update.
update_time = time_end_update - time_begin_update
print(
f"fully update model_name {model_name} rank {rank} parameter from distributed time: {update_time:.3f}s"
......@@ -320,7 +314,6 @@ def init_process_sgl(
param_queue.put((f"update_sgl_dp_{rank}_time", update_time))
# Get the weights of post-training model after weights update for correctness check.
base_params = []
for parameter_name in checking_parameters:
if backend == "Engine":
......@@ -340,7 +333,6 @@ def init_process_sgl(
param_queue.put((f"sgl_dp_{rank}_base_params", base_params))
# Shutdown the engine or terminate the server process.
if backend == "Engine":
engine.shutdown()
else:
......@@ -426,7 +418,6 @@ def test_update_weights_from_distributed(
# Check the correctness of weights update by verifying
# the weights of instruct model and base model.
for i in range(len(params["hf_instruct"])):
verify_params_close(
params["hf_instruct"][i],
......@@ -463,7 +454,6 @@ def test_update_weights_from_distributed(
), "hf_instruct_params and hf_base_params have different lengths"
# Check if the weights of lm_head are tied with embed_tokens.
params_to_check = [
(
params["hf_instruct"],
......@@ -509,7 +499,6 @@ def test_update_weights_from_distributed(
# Time limit for broadcast and update on CI is 3 / 6
# On local H100, it's 1 / 2
time_limit = 3 if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else 6
assert (
......@@ -526,7 +515,6 @@ def test_update_weights_from_distributed(
), f"update_sgl_dp_two_time exceeds time limit {time_limit}s"
# Delete the context and close the parameter queue.
del context
param_queue.close()
param_queue.join_thread()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment