Unverified commit ad82bac6, authored by Lianmin Zheng and committed by GitHub
Browse files

Fix model loading & format code (#125)

parent 71b54eea
...@@ -63,7 +63,9 @@ class Req: ...@@ -63,7 +63,9 @@ class Req:
# FIXME: This logic does not really solve the problem of determining whether # FIXME: This logic does not really solve the problem of determining whether
# there should be a leading space. # there should be a leading space.
first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0]) first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
first_token = first_token.decode() if isinstance(first_token, bytes) else first_token first_token = (
first_token.decode() if isinstance(first_token, bytes) else first_token
)
if first_token.startswith("▁"): if first_token.startswith("▁"):
old_output_str = " " + old_output_str old_output_str = " " + old_output_str
new_input_string = ( new_input_string = (
......
...@@ -344,9 +344,13 @@ class ModelRpcServer(rpyc.Service): ...@@ -344,9 +344,13 @@ class ModelRpcServer(rpyc.Service):
return None return None
if self.tp_rank == 0: if self.tp_rank == 0:
running_req = 0 if self.running_batch is None else len(self.running_batch.reqs) running_req = (
0 if self.running_batch is None else len(self.running_batch.reqs)
)
hit_tokens = sum(len(x.prefix_indices) for x in can_run_list) hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
self.tree_cache_metrics["total"] += (hit_tokens + new_batch_input_tokens) / 10**9 self.tree_cache_metrics["total"] += (
hit_tokens + new_batch_input_tokens
) / 10**9
self.tree_cache_metrics["hit"] += hit_tokens / 10**9 self.tree_cache_metrics["hit"] += hit_tokens / 10**9
tree_cache_hit_rate = ( tree_cache_hit_rate = (
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
...@@ -584,7 +588,7 @@ def start_model_process(port): ...@@ -584,7 +588,7 @@ def start_model_process(port):
t = ThreadedServer( t = ThreadedServer(
ModelRpcServer(), ModelRpcServer(),
port=port, port=port,
protocol_config={"allow_pickle": True, "sync_request_timeout": 600}, protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
) )
t.start() t.start()
...@@ -598,7 +602,7 @@ def start_model_process(port): ...@@ -598,7 +602,7 @@ def start_model_process(port):
con = rpyc.connect( con = rpyc.connect(
"localhost", "localhost",
port, port,
config={"allow_pickle": True, "sync_request_timeout": 600}, config={"allow_pickle": True, "sync_request_timeout": 1800},
) )
break break
except ConnectionRefusedError: except ConnectionRefusedError:
......
...@@ -351,7 +351,11 @@ class MixtralForCausalLM(nn.Module): ...@@ -351,7 +351,11 @@ class MixtralForCausalLM(nn.Module):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision model_name_or_path,
cache_dir,
load_format,
revision,
fall_back_to_pt=False,
): ):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment