"git@developer.sourcefind.cn:change/sglang.git" did not exist on "c7f254468fcae6a2df91765f9229aeeaf1d53613"
Unverified Commit 18108abe authored by Lianmin Zheng, committed by GitHub

[Minor] Fix code style (#2311)

parent c54bda30
@@ -25,7 +25,6 @@ import uuid
 from typing import Dict, List, Optional, Tuple, Union
 import fastapi
-import torch
 import uvloop
 import zmq
 import zmq.asyncio
@@ -337,6 +336,12 @@ class TokenizerManager:
                 rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
+            if batch_size > 128:
+                logger.warning(
+                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
+                    "The performance might be better if you just duplicate the requests n times or use "
+                    "many threads to send them one by one with parallel sampling (n > 1)."
+                )
             # Tokenize all requests
             objs = [obj[i] for i in range(batch_size)]
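The warning added above suggests a client-side workaround: instead of one large batch with parallel sampling, send the requests individually from many threads. Below is a minimal sketch of that pattern, assuming a locally running server on the default port and that sampling_params accepts an "n" field for parallel sampling; the URL, prompts, and parameters are placeholders, not part of this commit.

    # Hypothetical client-side workaround for the warning above.
    from concurrent.futures import ThreadPoolExecutor

    import requests

    BASE_URL = "http://localhost:30000"  # assumed default server address

    def generate_one(prompt: str):
        # Each request carries its own parallel sampling factor (n > 1).
        payload = {
            "text": prompt,
            "sampling_params": {"n": 4, "max_new_tokens": 64},
        }
        return requests.post(f"{BASE_URL}/generate", json=payload).json()

    prompts = [f"Question {i}: ..." for i in range(256)]
    with ThreadPoolExecutor(max_workers=32) as pool:
        results = list(pool.map(generate_one, prompts))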
@@ -494,9 +499,7 @@ class TokenizerManager:
             result = await self.parameter_update_result
             return result.success, result.message
         else:
-            logger.error(
-                f"Another parameter update is in progress in tokenizer manager"
-            )
+            logger.error("Another parameter update is in progress in tokenizer manager")
             return (
                 False,
                 "Another parameter update is in progress. Please try again later.",
@@ -597,7 +600,68 @@ class TokenizerManager:
                 InitWeightsUpdateGroupReqOutput,
             ] = await self.recv_from_detokenizer.recv_pyobj()
-            if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
+            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
+                for i, rid in enumerate(recv_obj.rids):
+                    state = self.rid_to_state.get(rid, None)
+                    if state is None:
+                        continue
+                    recv_obj.meta_info[i]["id"] = rid
+                    if isinstance(recv_obj, BatchStrOut):
+                        out_dict = {
+                            "text": recv_obj.output_strs[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    elif isinstance(recv_obj, BatchTokenIDOut):
+                        out_dict = {
+                            "token_ids": recv_obj.output_ids[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    else:
+                        assert isinstance(recv_obj, BatchEmbeddingOut)
+                        out_dict = {
+                            "embedding": recv_obj.embeddings[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    state.out_list.append(out_dict)
+                    state.finished = recv_obj.finished_reason[i] is not None
+                    state.event.set()
+                    if self.enable_metrics:
+                        completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
+                        if state.first_token_time is None:
+                            state.first_token_time = time.time()
+                            self.metrics_collector.observe_time_to_first_token(
+                                state.first_token_time - state.created_time
+                            )
+                        else:
+                            if completion_tokens >= 2:
+                                self.metrics_collector.observe_time_per_output_token(
+                                    (time.time() - state.first_token_time)
+                                    / (completion_tokens - 1)
+                                )
+                        if state.finished:
+                            self.metrics_collector.inc_prompt_tokens(
+                                recv_obj.meta_info[i]["prompt_tokens"]
+                            )
+                            self.metrics_collector.inc_generation_tokens(
+                                completion_tokens
+                            )
+                            self.metrics_collector.observe_e2e_request_latency(
+                                time.time() - state.created_time
+                            )
+                            if completion_tokens >= 1:
+                                self.metrics_collector.observe_time_per_output_token(
+                                    (time.time() - state.created_time)
+                                    / completion_tokens
+                                )
+            elif isinstance(recv_obj, OpenSessionReqOutput):
+                self.session_futures[recv_obj.session_id].set_result(
+                    recv_obj.session_id
+                )
+            elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
                 if self.server_args.dp_size == 1:
                     self.model_update_result.set_result(recv_obj)
                 else:  # self.server_args.dp_size > 1
@@ -605,13 +669,16 @@ class TokenizerManager:
                     # set future if the all results are recevied
                     if len(self.model_update_tmp) == self.server_args.dp_size:
                         self.model_update_result.set_result(self.model_update_tmp)
-                continue
+            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for init parameter update group"
+                self.init_weights_update_group_result.set_result(recv_obj)
             elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
                 assert (
                     self.server_args.dp_size == 1
                 ), "dp_size must be 1 for update weights from distributed"
                 self.parameter_update_result.set_result(recv_obj)
-                continue
             elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                 if self.server_args.dp_size == 1:
                     self.get_weights_by_name_result.set_result(recv_obj)
@@ -621,76 +688,8 @@ class TokenizerManager:
                         self.get_weights_by_name_result.set_result(
                             self.get_weights_by_name_tmp
                         )
-                continue
-            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be 1 for init parameter update group"
-                self.init_weights_update_group_result.set_result(recv_obj)
-                continue
-            elif isinstance(recv_obj, OpenSessionReqOutput):
-                self.session_futures[recv_obj.session_id].set_result(
-                    recv_obj.session_id
-                )
-                continue
-            assert isinstance(
-                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
-            ), f"Unexpected obj received: {type(recv_obj)}"
-            for i, rid in enumerate(recv_obj.rids):
-                state = self.rid_to_state.get(rid, None)
-                if state is None:
-                    continue
-                recv_obj.meta_info[i]["id"] = rid
-                if isinstance(recv_obj, BatchStrOut):
-                    out_dict = {
-                        "text": recv_obj.output_strs[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                elif isinstance(recv_obj, BatchTokenIDOut):
-                    out_dict = {
-                        "token_ids": recv_obj.output_ids[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                else:
-                    assert isinstance(recv_obj, BatchEmbeddingOut)
-                    out_dict = {
-                        "embedding": recv_obj.embeddings[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                state.out_list.append(out_dict)
-                state.finished = recv_obj.finished_reason[i] is not None
-                state.event.set()
-                if self.enable_metrics:
-                    completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
-                    if state.first_token_time is None:
-                        state.first_token_time = time.time()
-                        self.metrics_collector.observe_time_to_first_token(
-                            state.first_token_time - state.created_time
-                        )
-                    else:
-                        if completion_tokens >= 2:
-                            self.metrics_collector.observe_time_per_output_token(
-                                (time.time() - state.first_token_time)
-                                / (completion_tokens - 1)
-                            )
-                    if state.finished:
-                        self.metrics_collector.inc_prompt_tokens(
-                            recv_obj.meta_info[i]["prompt_tokens"]
-                        )
-                        self.metrics_collector.inc_generation_tokens(completion_tokens)
-                        self.metrics_collector.observe_e2e_request_latency(
-                            time.time() - state.created_time
-                        )
-                        if completion_tokens >= 1:
-                            self.metrics_collector.observe_time_per_output_token(
-                                (time.time() - state.created_time) / completion_tokens
-                            )
+            else:
+                raise ValueError(f"Invalid object: {recv_obj=}")
     def convert_logprob_style(
         self,
...
@@ -218,16 +218,6 @@ class ModelRunner:
         )
         self.tp_group = get_tp_group()
-        # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
-        # so we disable padding in cuda graph.
-        if self.device == "cuda" and not all(
-            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
-        ):
-            self.server_args.disable_cuda_graph_padding = True
-            logger.info(
-                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
-            )
         # Check memory for tensor parallelism
         if self.tp_size > 1:
             local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
...
@@ -82,7 +82,6 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     delete_directory,
-    init_custom_process_group,
     is_port_available,
     kill_process_tree,
     maybe_set_triton_cache_manager,
@@ -154,13 +153,11 @@ async def get_model_info():
 @app.get("/get_server_info")
 async def get_server_info():
-    try:
-        return await _get_server_info()
-    except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+    return {
+        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        **scheduler_info,
+        "version": __version__,
+    }
 @app.post("/flush_cache")
@@ -567,14 +564,6 @@ def launch_server(
         t.join()
-async def _get_server_info():
-    return {
-        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
-        **scheduler_info,
-        "version": __version__,
-    }
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -687,160 +676,6 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         delete_directory(server_args.model_path)
-class Runtime:
-    """
-    A wrapper for the server.
-    This is used for launching the server in a python program without
-    using the commond line interface.
-    """
-    def __init__(
-        self,
-        log_level: str = "error",
-        *args,
-        **kwargs,
-    ):
-        """See the arguments in server_args.py::ServerArgs"""
-        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
-        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
-        atexit.register(self.shutdown)
-        # Pre-allocate ports
-        for port in range(10000, 40000):
-            if is_port_available(port):
-                break
-            port += 1
-        self.server_args.port = port
-        self.url = self.server_args.url()
-        self.generate_url = self.url + "/generate"
-        # NOTE: We store pid instead of proc to fix some issues during __delete__
-        self.pid = None
-        pipe_reader, pipe_writer = mp.Pipe(duplex=False)
-        proc = mp.Process(
-            target=launch_server,
-            args=(self.server_args, pipe_writer),
-        )
-        proc.start()
-        pipe_writer.close()
-        self.pid = proc.pid
-        try:
-            init_state = pipe_reader.recv()
-        except EOFError:
-            init_state = ""
-        if init_state != "ready":
-            self.shutdown()
-            raise RuntimeError(
-                "Initialization failed. Please see the error messages above."
-            )
-        self.endpoint = RuntimeEndpoint(self.url)
-    def shutdown(self):
-        if self.pid is not None:
-            kill_process_tree(self.pid)
-            self.pid = None
-    def cache_prefix(self, prefix: str):
-        self.endpoint.cache_prefix(prefix)
-    def get_tokenizer(self):
-        return get_tokenizer(
-            self.server_args.tokenizer_path,
-            tokenizer_mode=self.server_args.tokenizer_mode,
-            trust_remote_code=self.server_args.trust_remote_code,
-        )
-    async def async_generate(
-        self,
-        prompt: str,
-        sampling_params: Optional[Dict] = None,
-    ):
-        if self.server_args.skip_tokenizer_init:
-            json_data = {
-                "input_ids": prompt,
-                "sampling_params": sampling_params,
-                "stream": True,
-            }
-        else:
-            json_data = {
-                "text": prompt,
-                "sampling_params": sampling_params,
-                "stream": True,
-            }
-        pos = 0
-        timeout = aiohttp.ClientTimeout(total=3 * 3600)
-        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
-            async with session.post(self.generate_url, json=json_data) as response:
-                async for chunk, _ in response.content.iter_chunks():
-                    chunk = chunk.decode("utf-8")
-                    if chunk and chunk.startswith("data:"):
-                        if chunk == "data: [DONE]\n\n":
-                            break
-                        data = json.loads(chunk[5:].strip("\n"))
-                        if "text" in data:
-                            cur = data["text"][pos:]
-                            if cur:
-                                yield cur
-                            pos += len(cur)
-                        else:
-                            yield data
-    add_request = async_generate
-    def generate(
-        self,
-        prompt: Union[str, List[str]],
-        sampling_params: Optional[Dict] = None,
-        return_logprob: Optional[Union[List[bool], bool]] = False,
-        logprob_start_len: Optional[Union[List[int], int]] = None,
-        top_logprobs_num: Optional[Union[List[int], int]] = None,
-        lora_path: Optional[List[Optional[str]]] = None,
-    ):
-        json_data = {
-            "text": prompt,
-            "sampling_params": sampling_params,
-            "return_logprob": return_logprob,
-            "logprob_start_len": logprob_start_len,
-            "top_logprobs_num": top_logprobs_num,
-            "lora_path": lora_path,
-        }
-        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
-        response = requests.post(
-            self.url + "/generate",
-            json=json_data,
-        )
-        return json.dumps(response.json())
-    def encode(
-        self,
-        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
-    ):
-        json_data = {"text": prompt}
-        response = requests.post(self.url + "/encode", json=json_data)
-        return json.dumps(response.json())
-    async def get_server_info(self):
-        async with aiohttp.ClientSession() as session:
-            async with session.get(f"{self.url}/get_server_info") as response:
-                if response.status == 200:
-                    return await response.json()
-                else:
-                    error_data = await response.json()
-                    raise RuntimeError(
-                        f"Failed to get server info. {error_data['error']['message']}"
-                    )
-    def __del__(self):
-        self.shutdown()
 STREAM_END_SYMBOL = b"data: [DONE]"
 STREAM_CHUNK_START_SYMBOL = b"data:"
@@ -854,6 +689,8 @@ class Engine:
     """
     def __init__(self, log_level: str = "error", *args, **kwargs):
+        """See the arguments in server_args.py::ServerArgs"""
         # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
         atexit.register(self.shutdown)
@@ -986,8 +823,12 @@ class Engine:
     def stop_profile(self):
         tokenizer_manager.stop_profile()
-    async def get_server_info(self):
-        return await _get_server_info()
+    def get_server_info(self):
+        return {
+            **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+            **scheduler_info,
+            "version": __version__,
+        }
     def init_weights_update_group(
         self,
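With this hunk, Engine.get_server_info becomes a plain synchronous method that returns the same dictionary the HTTP endpoint serves. A usage sketch, assuming the Engine class is importable from the sglang package; the model path is a placeholder.

    # Sketch only; model path is a placeholder.
    import sglang as sgl

    engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")
    info = engine.get_server_info()  # plain dict, no event loop needed after this change
    print(info["version"])
    engine.shutdown()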
@@ -1037,3 +878,160 @@ class Engine:
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(_get_weights())
+class Runtime:
+    """
+    A wrapper for the HTTP server.
+    This is used for launching the server in a python program without
+    using the commond line interface.
+    It is mainly used for the frontend language.
+    You should use the Engine class if you want to do normal offline processing.
+    """
+    def __init__(
+        self,
+        log_level: str = "error",
+        *args,
+        **kwargs,
+    ):
+        """See the arguments in server_args.py::ServerArgs"""
+        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
+        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
+        atexit.register(self.shutdown)
+        # Pre-allocate ports
+        for port in range(10000, 40000):
+            if is_port_available(port):
+                break
+            port += 1
+        self.server_args.port = port
+        self.url = self.server_args.url()
+        self.generate_url = self.url + "/generate"
+        # NOTE: We store pid instead of proc to fix some issues during __delete__
+        self.pid = None
+        pipe_reader, pipe_writer = mp.Pipe(duplex=False)
+        proc = mp.Process(
+            target=launch_server,
+            args=(self.server_args, pipe_writer),
+        )
+        proc.start()
+        pipe_writer.close()
+        self.pid = proc.pid
+        try:
+            init_state = pipe_reader.recv()
+        except EOFError:
+            init_state = ""
+        if init_state != "ready":
+            self.shutdown()
+            raise RuntimeError(
+                "Initialization failed. Please see the error messages above."
+            )
+        self.endpoint = RuntimeEndpoint(self.url)
+    def shutdown(self):
+        if self.pid is not None:
+            kill_process_tree(self.pid)
+            self.pid = None
+    def cache_prefix(self, prefix: str):
+        self.endpoint.cache_prefix(prefix)
+    def get_tokenizer(self):
+        return get_tokenizer(
+            self.server_args.tokenizer_path,
+            tokenizer_mode=self.server_args.tokenizer_mode,
+            trust_remote_code=self.server_args.trust_remote_code,
+        )
+    async def async_generate(
+        self,
+        prompt: str,
+        sampling_params: Optional[Dict] = None,
+    ):
+        if self.server_args.skip_tokenizer_init:
+            json_data = {
+                "input_ids": prompt,
+                "sampling_params": sampling_params,
+                "stream": True,
+            }
+        else:
+            json_data = {
+                "text": prompt,
+                "sampling_params": sampling_params,
+                "stream": True,
+            }
+        pos = 0
+        timeout = aiohttp.ClientTimeout(total=3 * 3600)
+        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
+            async with session.post(self.generate_url, json=json_data) as response:
+                async for chunk, _ in response.content.iter_chunks():
+                    chunk = chunk.decode("utf-8")
+                    if chunk and chunk.startswith("data:"):
+                        if chunk == "data: [DONE]\n\n":
+                            break
+                        data = json.loads(chunk[5:].strip("\n"))
+                        if "text" in data:
+                            cur = data["text"][pos:]
+                            if cur:
+                                yield cur
+                            pos += len(cur)
+                        else:
+                            yield data
+    add_request = async_generate
+    def generate(
+        self,
+        prompt: Union[str, List[str]],
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+    ):
+        json_data = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+            "return_logprob": return_logprob,
+            "logprob_start_len": logprob_start_len,
+            "top_logprobs_num": top_logprobs_num,
+            "lora_path": lora_path,
+        }
+        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
+        response = requests.post(
+            self.url + "/generate",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+    def encode(
+        self,
+        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
+    ):
+        json_data = {"text": prompt}
+        response = requests.post(self.url + "/encode", json=json_data)
+        return json.dumps(response.json())
+    async def get_server_info(self):
+        async with aiohttp.ClientSession() as session:
+            async with session.get(f"{self.url}/get_server_info") as response:
+                if response.status == 200:
+                    return await response.json()
+                else:
+                    error_data = await response.json()
+                    raise RuntimeError(
+                        f"Failed to get server info. {error_data['error']['message']}"
+                    )
+    def __del__(self):
+        self.shutdown()
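The relocated Runtime class keeps its role as the HTTP-server wrapper behind the frontend language, while its docstring now points offline users to Engine. A sketch of the frontend-language usage it is meant for, assuming the usual sglang frontend API; the model path and prompt are placeholders.

    # Sketch only; model path and prompt are placeholders.
    import sglang as sgl

    runtime = sgl.Runtime(model_path="meta-llama/Llama-3.1-8B-Instruct")
    sgl.set_default_backend(runtime)

    @sgl.function
    def qa(s, question):
        s += "Q: " + question + "\n"
        s += "A: " + sgl.gen("answer", max_tokens=64)

    state = qa.run(question="What is SGLang?")
    print(state["answer"])
    runtime.shutdown()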
@@ -67,7 +67,7 @@ class TestGetWeightsByName(unittest.TestCase):
         terminate_process(self.process)
     def assert_tie_word_embeddings(self, truncate_size):
-        print(f"assert_tie_word_embeddings")
+        print("assert_tie_word_embeddings")
         if self.backend == "Engine":
             backend_ret = _process_return(
                 self.engine.get_weights_by_name("lm_head.weight", truncate_size)
@@ -79,7 +79,7 @@ class TestGetWeightsByName(unittest.TestCase):
                     json={"name": "lm_head.weight", "truncate_size": truncate_size},
                ).json()
            )
-        print(f"assert_tie_word_embeddings of hf and backend")
+        print("assert_tie_word_embeddings of hf and backend")
         assert np.allclose(
             self.hf_model.get_parameter("model.embed_tokens.weight")
             .cpu()
...
@@ -127,7 +127,7 @@ def init_process_hf(
     hf_instruct_params = []
     hf_base_params = []
-    print(f"get parameter in hf instruct model and base model")
+    print("get parameter in hf instruct model and base model")
     for parameter_name in checking_parameters:
         hf_instruct_params.append(
             hf_instruct_model.get_parameter(parameter_name)[:truncate_size]
@@ -186,7 +186,6 @@ def init_process_hf(
     param_queue.put(("broadcast_time", broadcast_time))
     # Delete the huggingface models to free up memory.
     del hf_instruct_model
     del hf_base_model
     gc.collect()
@@ -238,7 +237,6 @@ def init_process_sgl(
     print(f"rank {rank} init server on url: {url}")
     # Get weights of instruct model, i.e. pre-training weights.
     instruct_params = []
     for parameter_name in checking_parameters:
         instruct_params.append(
@@ -253,7 +251,6 @@ def init_process_sgl(
     param_queue.put((f"sgl_dp_{rank}_instruct_params", instruct_params))
     # Init weight update group with the training engine.
     if backend == "Engine":
         engine.init_weights_update_group(
             master_address="localhost",
@@ -282,7 +279,6 @@ def init_process_sgl(
     # The last parameter is lm_head.weight, which is tied
     # with embed_tokens.weight. Actually, we only need
     # to update embed_tokens.weight once.
     tie_word_embeddings = (
         True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
     )
@@ -291,7 +287,6 @@ def init_process_sgl(
         update_parameters.remove("lm_head.weight")
     # Get weights from the training engine and update the inference engine.
     for parameter_name in update_parameters:
         if backend == "Engine":
             engine.update_weights_from_distributed(
@@ -312,7 +307,6 @@ def init_process_sgl(
     time_end_update = time.time()
     # Measure the latency of broadcast/weights update.
     update_time = time_end_update - time_begin_update
     print(
         f"fully update model_name {model_name} rank {rank} parameter from distributed time: {update_time:.3f}s"
@@ -320,7 +314,6 @@ def init_process_sgl(
     param_queue.put((f"update_sgl_dp_{rank}_time", update_time))
     # Get the weights of post-training model after weights update for correctness check.
     base_params = []
     for parameter_name in checking_parameters:
         if backend == "Engine":
@@ -340,7 +333,6 @@ def init_process_sgl(
     param_queue.put((f"sgl_dp_{rank}_base_params", base_params))
     # Shutdown the engine or terminate the server process.
     if backend == "Engine":
         engine.shutdown()
     else:
@@ -426,7 +418,6 @@ def test_update_weights_from_distributed(
     # Check the correctness of weights update by verifying
     # the weights of instruct model and base model.
     for i in range(len(params["hf_instruct"])):
         verify_params_close(
             params["hf_instruct"][i],
@@ -463,7 +454,6 @@ def test_update_weights_from_distributed(
     ), "hf_instruct_params and hf_base_params have different lengths"
     # Check if the weights of lm_head are tied with embed_tokens.
     params_to_check = [
         (
             params["hf_instruct"],
@@ -509,7 +499,6 @@ def test_update_weights_from_distributed(
     # Time limit for broadcast and update on CI is 3 / 6
     # On local H100, it's 1 / 2
     time_limit = 3 if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else 6
     assert (
@@ -526,7 +515,6 @@ def test_update_weights_from_distributed(
     ), f"update_sgl_dp_two_time exceeds time limit {time_limit}s"
     # Delete the context and close the parameter queue.
     del context
     param_queue.close()
     param_queue.join_thread()
...