Unverified Commit 9fdc6d6a authored by Qiaolin Yu, committed by GitHub

Fix the LoRA adapter when the lora path is None (#4799)


Co-authored-by: Beichen Ma <mabeichen12@gmail.com>
parent 42a45df0
@@ -133,10 +133,6 @@ class LoRAManager:
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
 
-        # FIXME: Handle lora uid with None more safely
-        if cur_uids == set([None]):
-            return
-
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
         seg_lens = (
...
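With the early return removed, the batch metadata below (batch size, segment lengths, and so on) is prepared even when every request in the batch has lora_path=None; the None uid is instead handled by zeroing its slot in the memory pool (see the LoRAMemoryPool hunk that follows). A minimal sketch of why a zeroed A buffer makes the adapter a no-op, assuming the usual LoRA delta x @ A^T @ B^T; the shapes here are illustrative, not taken from this repository:

import torch

x = torch.randn(2, 8)          # hidden states for two tokens
A = torch.zeros(4, 8)          # A slot zeroed for the None adapter (rank 4)
B = torch.randn(8, 4)          # whatever values remain in the B slot
delta = x @ A.T @ B.T          # LoRA contribution added to the base output
assert torch.all(delta == 0)   # zero A => zero delta, so the base output is unchanged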
@@ -163,7 +163,7 @@ class LoRAMemoryPool:
         if uid is None:
             for i in range(self.num_layer):
                 for k in self.A_buffer.keys():
-                    self.A_buffer[k][i][buffer_id] *= 0
+                    self.A_buffer[k][i][buffer_id] = 0
             return
 
         assert lora_adapter is not None
...
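The switch from "*= 0" to "= 0" matters if the pool buffer slot can hold NaN or Inf values (for example, stale or uninitialized memory): multiplying by zero leaves NaN entries in place, while direct assignment always clears the slot. Whether that was the exact failure behind this commit is not stated, but the arithmetic is easy to verify:

import torch

slot = torch.tensor([1.0, float("nan"), float("inf")])
mul = slot.clone()
mul *= 0        # in-place multiply: nan * 0 -> nan, inf * 0 -> nan
setz = slot.clone()
setz[:] = 0     # assignment clears every element regardless of its value
print(mul)      # tensor([0., nan, nan])
print(setz)     # tensor([0., 0., 0.])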
@@ -96,6 +96,11 @@ class TestLoRA(CustomTestCase):
             srt_outputs = srt_runner.forward(
                 prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
             )
+            srt_outputs_lora_path_none = srt_runner.forward(
+                prompts,
+                max_new_tokens=max_new_tokens,
+                lora_paths=[None] * len(prompts),
+            )
 
         with HFRunner(
             base_path, torch_dtype=torch_dtype, model_type="generation"
@@ -169,18 +174,20 @@ class TestLoRA(CustomTestCase):
         print(f"{srt_outputs.output_strs=}")
         print(f"{hf_no_lora_outputs.output_strs=}")
         print(f"{srt_no_lora_outputs.output_strs=}")
+        print(f"{srt_outputs_lora_path_none.output_strs=}")
         for i in range(len(prompts)):
             assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
                 srt_outputs.output_strs[i].strip(" "),
                 hf_outputs.output_strs[i],
             )
-            # assert (
-            #     srt_no_lora_outputs.output_strs[i].strip(" ")
-            #     == hf_no_lora_outputs.output_strs[i]
-            # ), (
-            #     srt_no_lora_outputs.output_strs[i].strip(" "),
-            #     hf_no_lora_outputs.output_strs[i],
-            # )
+            assert (
+                srt_no_lora_outputs.output_strs[i].strip(" ")
+                == hf_no_lora_outputs.output_strs[i]
+            ), (
+                srt_no_lora_outputs.output_strs[i].strip(" "),
+                hf_no_lora_outputs.output_strs[i],
+            )
+        assert srt_outputs_lora_path_none == srt_no_lora_outputs
 
     def serving(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
         print("=================== testing serving =======================")
@@ -257,7 +264,7 @@ class TestLoRA(CustomTestCase):
             srt_no_lora_logprobs = torch.Tensor(
                 srt_no_lora_outputs.top_input_logprobs[i]
             )
-            srt_logprobs = torch.uensor(srt_outputs.top_input_logprobs[i])
+            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
             print("max_diff", torch.max(abs(srt_no_lora_logprobs - srt_logprobs)))
 
         print(f"{srt_no_lora_outputs.output_strs=}")
@@ -280,7 +287,7 @@ class TestLoRA(CustomTestCase):
         tp_size = 1
         max_new_tokens = 32
         self.inference(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
-        # self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
+        self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
         # self.base_inference(
         #     PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens
         # )
...
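For reference, a hypothetical call pattern this fix enables, using the same srt_runner.forward interface as the test above. The adapter name is a placeholder, and srt_runner is assumed to come from the test's SRTRunner context; this is a usage sketch, not code from the repository:

# Hypothetical sketch: mixing LoRA and base-model requests in one batch.
# "my-adapter" is a placeholder adapter path; entries with lora_path=None
# fall back to the base model instead of crashing.
outputs = srt_runner.forward(
    ["prompt served with the adapter", "prompt served by the base model"],
    max_new_tokens=32,
    lora_paths=["my-adapter", None],
)
print(outputs.output_strs)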