Unverified commit 4474eaf5 authored by Lifu Huang, committed by GitHub

Support LoRA in TestOpenAIVisionServer and fix fused kv_proj loading bug. (#6861)

parent 499f5e62
@@ -165,14 +165,19 @@ class LoRAAdapter(nn.Module):
                     self.base_hf_config.hidden_size
                     // self.base_hf_config.num_attention_heads
                 )
-                weights[q_name], weights[kv_name] = torch.split(
+                weights[q_name], k_proj_weight, v_proj_weight = torch.split(
                     weights[qkv_name],
                     [
                         head_size * self.base_hf_config.num_attention_heads,
-                        head_size * self.base_hf_config.num_key_value_heads * 2,
+                        head_size * self.base_hf_config.num_key_value_heads,
+                        head_size * self.base_hf_config.num_key_value_heads,
                     ],
                     dim=0,
                 )
+                weights[kv_name] = torch.stack(
+                    [k_proj_weight, v_proj_weight],
+                    dim=0,
+                )
 
     def normalize_gate_up_proj(
         self, weight_names: List[str], weights: Dict[str, torch.Tensor]
......
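The bug fix above: a fused qkv_proj LoRA weight used to be split into q plus a single fused kv block of size 2 * num_key_value_heads * head_size, but the LoRA memory pool expects kv_proj as two stacked slices (one for K, one for V). The new code splits into three pieces and stacks K and V along a new leading dimension. A minimal standalone sketch with made-up sizes (not the sglang code):

import torch

hidden_size, num_heads, num_kv_heads = 32, 4, 2  # illustrative config
head_size = hidden_size // num_heads             # 8
rank = 4                                         # hypothetical LoRA rank

# Fused lora_B weight for qkv_proj: rows are [q | k | v] along dim 0.
qkv = torch.randn(head_size * (num_heads + 2 * num_kv_heads), rank)

# Fixed behavior: split into three projections instead of [q, fused kv].
q, k, v = torch.split(
    qkv,
    [
        head_size * num_heads,     # 32 rows for q
        head_size * num_kv_heads,  # 16 rows for k
        head_size * num_kv_heads,  # 16 rows for v
    ],
    dim=0,
)
kv = torch.stack([k, v], dim=0)  # (2, 16, rank): one slice per stacked projection
assert q.shape == (32, rank) and kv.shape == (2, 16, rank)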
@@ -157,6 +157,10 @@ class LoRAMemoryPool:
     def load_lora_weight_to_buffer(
         self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
     ):
+        def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
+            assert (
+                buffer_view.shape == weight.shape
+            ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
 
         if uid is None:
             for i in range(self.num_layer):
@@ -208,21 +212,27 @@ class LoRAMemoryPool:
             for name, weights in temp_A_buffer.items():
                 c = get_stacked_multiply(name)
-                self.A_buffer[name][layer_id][buffer_id][: lora_rank * c, :].copy_(
-                    weights
-                )
+                buffer_view = self.A_buffer[name][layer_id][buffer_id][
+                    : lora_rank * c, :
+                ]
+                check_lora_weight_shape(buffer_view, weights)
+                buffer_view.copy_(weights)
 
             for name, weights in temp_B_buffer.items():
                 c = get_stacked_multiply(name)
                 if c > 1:
                     for stacked_id in range(c):
-                        self.B_buffer[name][layer_id][stacked_id][buffer_id][
-                            :, :lora_rank
-                        ].copy_(weights[stacked_id])
+                        buffer_view = self.B_buffer[name][layer_id][stacked_id][
+                            buffer_id
+                        ][:, :lora_rank]
+                        check_lora_weight_shape(buffer_view, weights[stacked_id])
+                        buffer_view.copy_(weights[stacked_id])
                 else:
-                    self.B_buffer[name][layer_id][0][buffer_id][:, :lora_rank].copy_(
-                        weights
-                    )
+                    buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
+                        :, :lora_rank
+                    ]
+                    check_lora_weight_shape(buffer_view, weights)
+                    buffer_view.copy_(weights)
 
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
......
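The pool change follows a view-validate-copy pattern: slice a view of the preallocated buffer, assert the shapes match, then copy_ in place, so a mismatched adapter weight (such as the fused-kv shape before the fix above) fails with a readable message instead of an opaque copy_ error. A self-contained sketch of the pattern, with hypothetical buffer sizes rather than the sglang pool layout:

import torch

def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor) -> None:
    assert (
        buffer_view.shape == weight.shape
    ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."

buffer = torch.zeros(8, 16)   # hypothetical per-adapter buffer slot (max_rank=8)
weight = torch.randn(4, 16)   # adapter weight with lora_rank=4

view = buffer[: weight.shape[0], :]  # a view, so copy_ writes into `buffer`
check_lora_weight_shape(view, weight)
view.copy_(weight)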
@@ -177,9 +177,19 @@ class TestKimiVLServer(TestOpenAIVisionServer):
 class TestPhi4MMServer(TestOpenAIVisionServer):
     @classmethod
     def setUpClass(cls):
+        # Manually download LoRA adapter_config.json as it's not downloaded by the model loader by default.
+        from huggingface_hub import constants, snapshot_download
+
+        snapshot_download(
+            "microsoft/Phi-4-multimodal-instruct",
+            allow_patterns=["**/adapter_config.json"],
+        )
+
         cls.model = "microsoft/Phi-4-multimodal-instruct"
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.api_key = "sk-123456"
+        revision = "33e62acdd07cd7d6635badd529aa0a3467bb9c6a"
         cls.process = popen_launch_server(
             cls.model,
             cls.base_url,
@@ -188,15 +198,27 @@ class TestPhi4MMServer(TestOpenAIVisionServer):
                 "--trust-remote-code",
                 "--mem-fraction-static",
                 "0.75",
+                "--disable-radix-cache",
+                "--max-loras-per-batch",
+                "1",
+                "--revision",
+                revision,
+                "--lora-paths",
+                f"vision={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/vision-lora",
             ],
         )
         cls.base_url += "/v1"
 
-    def test_video_chat_completion(self):
-        pass
-
-    def test_multi_images_chat_completion(self):
+    def get_request_kwargs(self):
+        return {
+            "extra_body": {
+                "lora_path": "vision",
+                "top_k": 1,
+                "top_p": 1.0,
+            }
+        }
+
+    def test_video_chat_completion(self):
+        # TODO (lifuhuang): support LoRA to enable Phi4MM multi-image understanding capability.
         pass
......
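For context, the kwargs returned by the new get_request_kwargs override are splatted into each chat.completions.create call, and the OpenAI SDK forwards extra_body fields to the server verbatim, which is how lora_path reaches sglang. A sketch of an equivalent standalone request, assuming a server already launched the way setUpClass does:

import openai

# Assumed local endpoint; matches the api_key used in the test above.
client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Describe the image."}],
    temperature=0,
    # Fields the OpenAI API doesn't know about travel in extra_body.
    extra_body={"lora_path": "vision", "top_k": 1, "top_p": 1.0},
)
print(response.choices[0].message.content)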
 import base64
+import copy
 import io
 import json
 import os
@@ -47,6 +48,9 @@ class TestOpenAIVisionServer(CustomTestCase):
     def tearDownClass(cls):
         kill_process_tree(cls.process.pid)
 
+    def get_request_kwargs(self):
+        return {}
+
     def test_single_image_chat_completion(self):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
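This no-op hook is a small template-method pattern: the base class contributes nothing, and subclasses that need LoRA (like TestPhi4MMServer above) override it to attach per-request kwargs. A minimal standalone sketch of the mechanism, with hypothetical class names:

from typing import Any, Dict

class BaseVisionTest:
    def get_request_kwargs(self) -> Dict[str, Any]:
        return {}

    def make_request_args(self) -> Dict[str, Any]:
        # Every call site splats the hook's result into the request.
        return {"model": "default", "temperature": 0, **self.get_request_kwargs()}

class LoRAVisionTest(BaseVisionTest):
    def get_request_kwargs(self) -> Dict[str, Any]:
        return {"extra_body": {"lora_path": "vision"}}

print(BaseVisionTest().make_request_args())  # no extra_body
print(LoRAVisionTest().make_request_args())  # includes extra_body with lora_path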
...@@ -68,6 +72,7 @@ class TestOpenAIVisionServer(CustomTestCase): ...@@ -68,6 +72,7 @@ class TestOpenAIVisionServer(CustomTestCase):
}, },
], ],
temperature=0, temperature=0,
**(self.get_request_kwargs()),
) )
assert response.choices[0].message.role == "assistant" assert response.choices[0].message.role == "assistant"
...@@ -130,6 +135,7 @@ class TestOpenAIVisionServer(CustomTestCase): ...@@ -130,6 +135,7 @@ class TestOpenAIVisionServer(CustomTestCase):
}, },
], ],
temperature=0, temperature=0,
**(self.get_request_kwargs()),
) )
assert response.choices[0].message.role == "assistant" assert response.choices[0].message.role == "assistant"
...@@ -172,6 +178,7 @@ class TestOpenAIVisionServer(CustomTestCase): ...@@ -172,6 +178,7 @@ class TestOpenAIVisionServer(CustomTestCase):
}, },
], ],
temperature=0, temperature=0,
**(self.get_request_kwargs()),
) )
assert response.choices[0].message.role == "assistant" assert response.choices[0].message.role == "assistant"
...@@ -284,6 +291,7 @@ class TestOpenAIVisionServer(CustomTestCase): ...@@ -284,6 +291,7 @@ class TestOpenAIVisionServer(CustomTestCase):
temperature=0, temperature=0,
max_tokens=1024, max_tokens=1024,
stream=False, stream=False,
**(self.get_request_kwargs()),
) )
video_response = response.choices[0].message.content video_response = response.choices[0].message.content
@@ -324,6 +332,9 @@ class TestOpenAIVisionServer(CustomTestCase):
             + r"""\}"""
         )
 
+        extra_kwargs = self.get_request_kwargs()
+        extra_kwargs.setdefault("extra_body", {})["regex"] = regex
+
         response = client.chat.completions.create(
             model="default",
             messages=[
@@ -342,7 +353,7 @@ class TestOpenAIVisionServer(CustomTestCase):
                 },
             ],
             temperature=0,
-            extra_body={"regex": regex},
+            **extra_kwargs,
         )
 
         text = response.choices[0].message.content
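setdefault is doing the real work here: the regex test must merge its own extra_body entry with whatever a subclass's get_request_kwargs already supplies, rather than clobber it. A quick sketch of both cases; note this mutates the returned dict, which is safe because each override above builds a fresh literal per call:

regex = r"\{.*\}"  # placeholder pattern

# Subclass case: extra_body already carries LoRA kwargs; regex is merged in.
extra_kwargs = {"extra_body": {"lora_path": "vision"}}
extra_kwargs.setdefault("extra_body", {})["regex"] = regex
assert extra_kwargs == {"extra_body": {"lora_path": "vision", "regex": regex}}

# Base-class case: no extra_body yet, so setdefault creates one.
extra_kwargs = {}
extra_kwargs.setdefault("extra_body", {})["regex"] = regex
assert extra_kwargs == {"extra_body": {"regex": regex}}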
@@ -388,6 +399,7 @@ class TestOpenAIVisionServer(CustomTestCase):
                 {"role": "user", "content": content},
             ],
             temperature=0,
+            **(self.get_request_kwargs()),
         )
 
         assert response.choices[0].message.role == "assistant"
@@ -430,6 +442,7 @@ class TestOpenAIVisionServer(CustomTestCase):
             temperature=0,
             max_tokens=128,
             stream=False,
+            **(self.get_request_kwargs()),
         )
 
         audio_response = response.choices[0].message.content
......