Commit 2f5f9acd (unverified), authored by Jee Jee Li, committed by GitHub
Browse files

[LoRA] Continue optimizing MoE LoRA weight loading (#29322)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent cf348c8d
...@@ -28,12 +28,13 @@ def test_load_checkpoints( ...@@ -28,12 +28,13 @@ def test_load_checkpoints(
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules: list[str] = [] expected_lora_lst: list[str] = []
for module in BAICHUAN_LORA_MODULES: for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping: if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module]) expected_lora_lst.extend(packed_modules_mapping[module])
else: else:
expected_lora_modules.append(module) expected_lora_lst.append(module)
expected_lora_modules = set(expected_lora_lst)
if lora_name == "baichuan7B": if lora_name == "baichuan7B":
peft_helper = PEFTHelper.from_local_dir( peft_helper = PEFTHelper.from_local_dir(
baichuan_lora_files, max_position_embeddings=4096 baichuan_lora_files, max_position_embeddings=4096
...@@ -103,13 +104,13 @@ def test_lora_weights_mapping(baichuan_lora_files): ...@@ -103,13 +104,13 @@ def test_lora_weights_mapping(baichuan_lora_files):
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules: list[str] = [] expected_lora_lst: list[str] = []
for module in BAICHUAN_LORA_MODULES: for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping: if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module]) expected_lora_lst.extend(packed_modules_mapping[module])
else: else:
expected_lora_modules.append(module) expected_lora_lst.append(module)
expected_lora_modules = set(expected_lora_lst)
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
"model.": "language_model.model.", "model.": "language_model.model.",
......
...@@ -26,13 +26,13 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request): ...@@ -26,13 +26,13 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
embedding_modules = LlamaForCausalLM.embedding_modules embedding_modules = LlamaForCausalLM.embedding_modules
embed_padding_modules = LlamaForCausalLM.embedding_padding_modules embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
expected_lora_modules: list[str] = [] expected_lora_lst: list[str] = []
for module in LLAMA_LORA_MODULES: for module in LLAMA_LORA_MODULES:
if module in packed_modules_mapping: if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module]) expected_lora_lst.extend(packed_modules_mapping[module])
else: else:
expected_lora_modules.append(module) expected_lora_lst.append(module)
expected_lora_modules = set(expected_lora_lst)
lora_path = get_adapter_absolute_path(lora_name) lora_path = get_adapter_absolute_path(lora_name)
# lora loading should work for either absolute path and huggingface id. # lora loading should work for either absolute path and huggingface id.
......
...@@ -60,7 +60,7 @@ class BaseLayerWithLoRA(nn.Module): ...@@ -60,7 +60,7 @@ class BaseLayerWithLoRA(nn.Module):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer.""" """Returns True if the layer can be replaced by this LoRA layer."""
raise NotImplementedError raise NotImplementedError
...@@ -153,7 +153,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -153,7 +153,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return type(source_layer) is ColumnParallelLinear or ( return type(source_layer) is ColumnParallelLinear or (
type(source_layer) is MergedColumnParallelLinear type(source_layer) is MergedColumnParallelLinear
...@@ -272,7 +272,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -272,7 +272,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return ( return (
type(source_layer) is MergedColumnParallelLinear type(source_layer) is MergedColumnParallelLinear
...@@ -338,7 +338,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -338,7 +338,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1 return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1
...@@ -396,7 +396,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): ...@@ -396,7 +396,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3 return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3
...@@ -434,7 +434,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): ...@@ -434,7 +434,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer( return super().can_replace_layer(
...@@ -480,7 +480,7 @@ class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLo ...@@ -480,7 +480,7 @@ class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLo
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer( return super().can_replace_layer(
...@@ -516,7 +516,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): ...@@ -516,7 +516,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer( return super().can_replace_layer(
...@@ -565,7 +565,7 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): ...@@ -565,7 +565,7 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer( return super().can_replace_layer(
......
...@@ -401,6 +401,61 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -401,6 +401,61 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w13_lora_b_stacked[1][lora_id][experts_id] self.w13_lora_b_stacked[1][lora_id][experts_id]
) )
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
    """Return this rank's shard of a W13/W1/W3 LoRA A weight.

    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA.

    The input is laid out as (num_experts, rank, input_size). The tensor is
    returned untouched unless tensor parallelism is active (``tp_size != 1``)
    and ``fully_sharded`` is set; in that case, following S-LoRA, the rank
    dim is split evenly across ranks and the chunk owned by ``tp_rank`` is
    returned.
    """
    shard_across_ranks = self.tp_size != 1 and self.fully_sharded
    if not shard_across_ranks:
        return w13_lora_a

    # dim 1 is the LoRA rank; it has to split evenly across the TP group.
    rank_per_shard, leftover = divmod(w13_lora_a.shape[1], self.tp_size)
    assert leftover == 0

    first = self.tp_rank * rank_per_shard
    return w13_lora_a[:, first : first + rank_per_shard, :]
def _slice_w13_b(self, w13_lora_b: torch.Tensor) -> torch.Tensor:
    """Return this rank's shard of a W13/W1/W3 LoRA B weight.

    The input is laid out as (num_experts, output_size, rank). With
    ``tp_size == 1`` it is returned untouched; otherwise the output_size dim
    is cut into chunks of ``base_layer.intermediate_size_per_partition`` rows
    and the chunk at position ``tp_rank`` is returned.
    """
    if self.tp_size == 1:
        return w13_lora_b

    rows_per_rank = self.base_layer.intermediate_size_per_partition
    first_row = self.tp_rank * rows_per_rank
    return w13_lora_b[:, first_row : first_row + rows_per_rank, :]
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
    """Return this rank's shard of a W2 LoRA A weight.

    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA.

    The input is laid out as (num_experts, rank, input_size). With
    ``tp_size == 1`` it is returned untouched; otherwise the input_size
    (last) dim is cut into chunks of
    ``base_layer.intermediate_size_per_partition`` columns and the chunk at
    position ``tp_rank`` is returned.
    """
    if self.tp_size == 1:
        return w2_lora_a

    cols_per_rank = self.base_layer.intermediate_size_per_partition
    first_col = self.tp_rank * cols_per_rank
    return w2_lora_a[:, :, first_col : first_col + cols_per_rank]
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
    """Return this rank's shard of a W2 LoRA B weight.

    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA.

    The input is laid out as (num_experts, output_size, rank). The tensor is
    returned untouched unless tensor parallelism is active (``tp_size != 1``)
    and ``fully_sharded`` is set; in that case, following S-LoRA, the
    output_size (hidden_size) dim is split across ranks and the chunk owned
    by ``tp_rank`` is returned.
    """
    shard_across_ranks = self.tp_size != 1 and self.fully_sharded
    if not shard_across_ranks:
        return w2_lora_b

    # NOTE(review): floor division, with no divisibility check (unlike
    # _slice_w13_a) - any remainder rows are not assigned to a rank.
    rows_per_rank = w2_lora_b.shape[1] // self.tp_size
    first_row = self.tp_rank * rows_per_rank
    return w2_lora_b[:, first_row : first_row + rows_per_rank, :]
def reset_lora(self, index: int): def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0.""" """Resets the lora weights at index back to 0."""
for pos in range(self._w13_slices): for pos in range(self._w13_slices):
...@@ -411,6 +466,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -411,6 +466,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w2_lora_b_stacked[0][index] = 0 self.w2_lora_b_stacked[0][index] = 0
self.adapter_enabled[index] = 0 self.adapter_enabled[index] = 0
#
def set_lora( def set_lora(
self, self,
index: int, index: int,
...@@ -418,69 +475,55 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -418,69 +475,55 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
lora_b: torch.Tensor | list[torch.Tensor], lora_b: torch.Tensor | list[torch.Tensor],
): ):
"""Overwrites lora tensors at index.""" """Overwrites lora tensors at index."""
# Make mypy happy
assert isinstance(lora_a, list) assert isinstance(lora_a, list)
assert isinstance(lora_b, list) assert isinstance(lora_b, list)
self.reset_lora(index) self.reset_lora(index)
self.adapter_enabled[index] = 1 self.adapter_enabled[index] = 1
for eid in range(len(lora_a) // 3):
w1_lora_a = lora_a[eid * 3]
w2_lora_a = lora_a[eid * 3 + 1]
w3_lora_a = lora_a[eid * 3 + 2]
w1_lora_b = lora_b[eid * 3]
w2_lora_b = lora_b[eid * 3 + 1]
w3_lora_b = lora_b[eid * 3 + 2]
# Handle the case of adding LoRA to only a subset of experts
if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
continue
if self.tp_size > 1:
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
w1_lora_b = w1_lora_b[start_idx:end_idx, :] num_experts = self.w13_lora_a_stacked[0].shape[1]
w3_lora_b = w3_lora_b[start_idx:end_idx, :]
w2_lora_a = w2_lora_a[:, start_idx:end_idx] w1_lora_a, w2_lora_a, w3_lora_a = lora_a
w1_lora_b, w2_lora_b, w3_lora_b = lora_b
if self.fully_sharded: assert (
# Based on S-LoRA, we slice W1 and W3 A along the rank dim, num_experts
# and W2 B along the hidden_size dim. == w1_lora_a.shape[0]
w13_shard_size = self.w13_lora_a_stacked[0][index, eid].shape[0] == w2_lora_a.shape[0]
w13_start_idx = self.tp_rank * w13_shard_size == w3_lora_a.shape[0]
w13_end_idx = (self.tp_rank + 1) * w13_shard_size )
w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :] slliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
slliced_w1_lora_b = self._slice_w13_b(w1_lora_b)
w2_shard_size = self.w2_lora_b_stacked[0][index, eid].shape[0] slliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
w2_start_idx = self.tp_rank * w2_shard_size slliced_w3_lora_b = self._slice_w13_b(w3_lora_b)
w2_end_idx = (self.tp_rank + 1) * w2_shard_size
w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :] sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
# w1 lora_a sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
self.w13_lora_a_stacked[0][ self.w13_lora_a_stacked[0][
index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1] index, :, : slliced_w1_lora_a.shape[1], : slliced_w1_lora_a.shape[2]
].copy_(w1_lora_a, non_blocking=True) ].copy_(slliced_w1_lora_a, non_blocking=True)
# w3 lora_a
self.w13_lora_a_stacked[1][ self.w13_lora_a_stacked[1][
index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1] index, :, : slliced_w3_lora_a.shape[1], : slliced_w3_lora_a.shape[2]
].copy_(w3_lora_a, non_blocking=True) ].copy_(slliced_w3_lora_a, non_blocking=True)
# w1 lora_b
self.w13_lora_b_stacked[0][ self.w13_lora_b_stacked[0][
index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1] index, :, : slliced_w1_lora_b.shape[1], : slliced_w1_lora_b.shape[2]
].copy_(w1_lora_b, non_blocking=True) ].copy_(slliced_w1_lora_b, non_blocking=True)
# w3 lora_b
self.w13_lora_b_stacked[1][ self.w13_lora_b_stacked[1][
index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1] index, :, : slliced_w3_lora_b.shape[1], : slliced_w3_lora_b.shape[2]
].copy_(w3_lora_b, non_blocking=True) ].copy_(slliced_w3_lora_b, non_blocking=True)
self.w2_lora_a_stacked[0][ self.w2_lora_a_stacked[0][
index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1] index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
].copy_(w2_lora_a, non_blocking=True) ].copy_(sliced_w2_lora_a, non_blocking=True)
self.w2_lora_b_stacked[0][ self.w2_lora_b_stacked[0][
index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1] index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
].copy_(w2_lora_b, non_blocking=True) ].copy_(sliced_w2_lora_b, non_blocking=True)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return self.base_layer.forward(*args, **kwargs) return self.base_layer.forward(*args, **kwargs)
...@@ -506,12 +549,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -506,12 +549,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer.""" """Returns True if the layer can be replaced by this LoRA layer."""
# return type(source_layer) is FusedMoE
return type(source_layer) is FusedMoE and len(packed_modules_list) == 2 # source_layer is FusedMoE or SharedFusedMoE
return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 2
class FusedMoE3DWithLoRA(FusedMoEWithLoRA): class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
...@@ -555,6 +598,9 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA): ...@@ -555,6 +598,9 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
model_config: PretrainedConfig | None = None, model_config: PretrainedConfig | None = None,
) -> None: ) -> None:
"""Initializes lora matrices.""" """Initializes lora matrices."""
assert isinstance(model_config, PretrainedConfig)
self._base_model = model_config.architectures[0]
self.max_loras = lora_config.max_loras self.max_loras = lora_config.max_loras
self.fully_sharded = lora_config.fully_sharded_loras self.fully_sharded = lora_config.fully_sharded_loras
...@@ -565,20 +611,7 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA): ...@@ -565,20 +611,7 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
self._create_lora_a_weights(max_loras, lora_config) self._create_lora_a_weights(max_loras, lora_config)
self._create_lora_b_weights(max_loras, lora_config) self._create_lora_b_weights(max_loras, lora_config)
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor: def _slice_w13_b(self, w13_lora_b: torch.Tensor):
if self.tp_size == 1 or not self.fully_sharded:
return w13_lora_a
# w13_lora_a shape (num_experts,rank,input_size)
current_lora_rank = w13_lora_a.shape[1]
assert current_lora_rank % self.tp_size == 0
sliced_rank = current_lora_rank // self.tp_size
start_idx = self.tp_rank * sliced_rank
end_idx = (self.tp_rank + 1) * sliced_rank
return w13_lora_a[:, start_idx:end_idx, :]
def _slice_w13_b(self, w13_lora_b: torch.Tensor, is_interleave: bool = True):
if self.tp_size == 1: if self.tp_size == 1:
return w13_lora_b return w13_lora_b
...@@ -586,7 +619,8 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA): ...@@ -586,7 +619,8 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
shard_size = self.base_layer.intermediate_size_per_partition shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size end_idx = (self.tp_rank + 1) * shard_size
if is_interleave: # HACK: Currently, only GPT-OSS is in interleaved order
if self._base_model == "GptOssForCausalLM":
# For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj) # For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
# in the interleaved order, and corresponding LoRA need to be processed. # in the interleaved order, and corresponding LoRA need to be processed.
w1_lora_b = w13_lora_b[:, ::2, :] w1_lora_b = w13_lora_b[:, ::2, :]
...@@ -606,28 +640,6 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA): ...@@ -606,28 +640,6 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1) return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
if self.tp_size == 1:
return w2_lora_a
# w2_lora_a shape (num_experts,rank,input_size)
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w2_lora_a[:, :, start_idx:end_idx]
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
if self.tp_size == 1 or not self.fully_sharded:
return w2_lora_b
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
# w2_lora_b shape (num_experts,output_size,rank)
current_lora_size = w2_lora_b.shape[1]
sliced_size = current_lora_size // self.tp_size
start_idx = self.tp_rank * sliced_size
end_idx = (self.tp_rank + 1) * sliced_size
return w2_lora_b[:, start_idx:end_idx, :]
def set_lora( def set_lora(
self, self,
index: int, index: int,
...@@ -658,7 +670,7 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA): ...@@ -658,7 +670,7 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
w2_lora_b = w2_lora_b.permute(1, 0, 2) w2_lora_b = w2_lora_b.permute(1, 0, 2)
sliced_w13_lora_a = self._slice_w13_a(w13_lora_a) sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
sliced_w13_lora_b = self._slice_w13_b(w13_lora_b, is_interleave=True) sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)
sliced_w2_lora_a = self._slice_w2_a(w2_lora_a) sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
sliced_w2_lora_b = self._slice_w2_b(w2_lora_b) sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
...@@ -711,8 +723,8 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA): ...@@ -711,8 +723,8 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer.""" """Returns True if the layer can be replaced by this LoRA layer."""
# source_layer is FusedMoE or SharedFusedMoE
return type(source_layer) is FusedMoE and len(packed_modules_list) == 1 return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 1
...@@ -197,7 +197,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -197,7 +197,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
# Special handling for the LogitsProcessor. # Special handling for the LogitsProcessor.
return False return False
...@@ -53,7 +53,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -53,7 +53,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return type(source_layer) is ReplicatedLinear return type(source_layer) is ReplicatedLinear
......
...@@ -87,7 +87,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -87,7 +87,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return type(source_layer) is RowParallelLinear return type(source_layer) is RowParallelLinear
...@@ -164,7 +164,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): ...@@ -164,7 +164,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer( return super().can_replace_layer(
......
...@@ -131,7 +131,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -131,7 +131,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return type(source_layer) is VocabParallelEmbedding return type(source_layer) is VocabParallelEmbedding
......
...@@ -152,6 +152,59 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -152,6 +152,59 @@ class PackedLoRALayerWeights(LoRALayerWeights):
) )
return obj return obj
@classmethod
def pack_moe(
    cls, loras: GenericSequence[Optional["LoRALayerWeights"]], module_name: str
) -> "PackedLoRALayerWeights":
    """Pack per-expert MoE LoRAs into a single stacked LoRA.

    ``loras`` is read as consecutive ``(w1, w2, w3)`` triples, one triple per
    expert. The A and B matrices of each projection are stacked along a new
    leading expert dim, giving tensors of shape
    (num_experts, rank, input_size) for A and
    (num_experts, output_size, rank) for B.

    Args:
        loras: Flat sequence of per-expert LoRA weights. A ``None`` entry
            signifies that the submodule does not have a LoRA; this is not
            supported yet and is rejected.
        module_name: Name given to the packed LoRA.

    Returns:
        A packed LoRA whose ``lora_a`` / ``lora_b`` are three-element lists
        ordered ``[w1, w2, w3]``. Rank and alpha are taken from the first
        LoRA in ``loras``.

    Raises:
        ValueError: If ``loras`` holds no LoRA at all, its length is not a
            multiple of 3, or any entry is ``None``.
    """
    # A default is passed to next() so that an empty / all-None input is
    # reported explicitly instead of leaking a bare StopIteration.
    first_lora = next((lora for lora in loras if lora is not None), None)
    if first_lora is None:
        raise ValueError(
            f"Cannot pack MoE LoRA for {module_name!r}: no LoRA weights given."
        )
    rank = first_lora.rank
    lora_alpha = first_lora.lora_alpha

    if len(loras) % 3 != 0:
        raise ValueError(
            f"Cannot pack MoE LoRA for {module_name!r}: expected (w1, w2, w3) "
            f"triples but got {len(loras)} entries."
        )

    # TODO: Consider the case where some experts don't have LoRA added.
    # Validated explicitly (not via assert) so the check survives `python -O`.
    present = []
    for position, lora in enumerate(loras):
        if lora is None:
            raise ValueError(
                f"Cannot pack MoE LoRA for {module_name!r}: entry {position} "
                "has no LoRA weights."
            )
        present.append(lora)

    # Entries 0, 3, 6, ... are w1; 1, 4, 7, ... are w2; 2, 5, 8, ... are w3.
    per_projection = (present[0::3], present[1::3], present[2::3])
    # Each stacked A is (num_experts, rank, input_size).
    stacked_a = [
        torch.stack([lora.lora_a for lora in group], dim=0)
        for group in per_projection
    ]
    # Each stacked B is (num_experts, output_size, rank).
    stacked_b = [
        torch.stack([lora.lora_b for lora in group], dim=0)
        for group in per_projection
    ]

    return cls(
        module_name,
        rank,
        [lora_alpha, lora_alpha, lora_alpha],
        stacked_a,
        stacked_b,
    )
def optimize(self) -> "PackedLoRALayerWeights": def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b.""" """Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)): for i in range(len(self.lora_b)):
......
...@@ -13,7 +13,7 @@ from torch import nn ...@@ -13,7 +13,7 @@ from torch import nn
from vllm.config.lora import LoRAConfig from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoEWithLoRA, LoRAMapping from vllm.lora.layers import BaseLayerWithLoRA, FusedMoE3DWithLoRA, LoRAMapping
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.lora.punica_wrapper import get_punica_wrapper
...@@ -151,16 +151,13 @@ class LoRAModel: ...@@ -151,16 +151,13 @@ class LoRAModel:
if pin_memory: if pin_memory:
loras[module_name].lora_b = loras[module_name].lora_b.pin_memory() loras[module_name].lora_b = loras[module_name].lora_b.pin_memory()
for lora in loras.values():
lora.optimize()
return cls(lora_model_id, peft_helper.r, loras) return cls(lora_model_id, peft_helper.r, loras)
@classmethod @classmethod
def from_local_checkpoint( def from_local_checkpoint(
cls, cls,
lora_dir: str, lora_dir: str,
expected_lora_modules: list[str], expected_lora_modules: set[str],
peft_helper: PEFTHelper, peft_helper: PEFTHelper,
*, *,
lora_model_id: int | None = None, lora_model_id: int | None = None,
...@@ -190,10 +187,7 @@ class LoRAModel: ...@@ -190,10 +187,7 @@ class LoRAModel:
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt") lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
# new_embeddings_tensor_path = os.path.join(
# lora_dir, "new_embeddings.safetensors"
# )
# new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
tensors: dict[str, torch.Tensor] = {} tensors: dict[str, torch.Tensor] = {}
unexpected_modules: list[list[str] | str] = [] unexpected_modules: list[list[str] | str] = []
...@@ -201,18 +195,19 @@ class LoRAModel: ...@@ -201,18 +195,19 @@ class LoRAModel:
for lora_module in modules.keys(): # noqa for lora_module in modules.keys(): # noqa
if is_base_embeddding_weights(lora_module): if is_base_embeddding_weights(lora_module):
continue continue
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper) # Handle PEFT file format where experts.base_layer is the
# Handle FSDP file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj # gate_up_proj and experts is the down_proj
if "base_layer" in lora_module: if "base_layer" in lora_module:
continue continue
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
# Case for expert lora weights # Case for expert lora weights
if ".experts" in module_name: if ".experts" in module_name:
if not any( expert_idx = module_name.find(".experts")
module_name.endswith(ele) for ele in expected_lora_modules expert_suffix = module_name[expert_idx + 1 :]
): if expert_suffix not in expected_lora_modules:
unexpected_modules.append(module_name) unexpected_modules.append(module_name)
elif module_name.split(".")[-1] not in expected_lora_modules:
elif module_name.rsplit(".", 1)[-1] not in expected_lora_modules:
unexpected_modules.append(module_name) unexpected_modules.append(module_name)
if unexpected_modules: if unexpected_modules:
...@@ -358,9 +353,7 @@ class LoRAModelManager: ...@@ -358,9 +353,7 @@ class LoRAModelManager:
self.modules: dict[str, BaseLayerWithLoRA] = {} self.modules: dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a set for compatibility with LRUCache. # Dict instead of a set for compatibility with LRUCache.
self._last_mapping: LoRAMapping | None = None self._last_mapping: LoRAMapping | None = None
self._is_3d_moe_model = is_moe_model(self.model) and hasattr( self._is_3d_moe_model = is_moe_model(self.model) and self.model.is_3d_moe_weight
self.model, "is_3d_moe_weight"
)
self._create_lora_modules() self._create_lora_modules()
self.model.lora_manager = self self.model.lora_manager = self
...@@ -411,7 +404,7 @@ class LoRAModelManager: ...@@ -411,7 +404,7 @@ class LoRAModelManager:
continue continue
# Note (gnovack) - If MOE lora weights are not split into # Note (gnovack) - If MOE lora weights are not split into
# num_experts chunks, we split them here # num_experts chunks, we split them here
if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor( if isinstance(module, FusedMoE3DWithLoRA) and torch.is_tensor(
module_lora.lora_a module_lora.lora_a
): ):
# Handle PEFT file format where experts.base_layer is the # Handle PEFT file format where experts.base_layer is the
...@@ -679,6 +672,9 @@ class LoRAModelManager: ...@@ -679,6 +672,9 @@ class LoRAModelManager:
"cpu", "cpu",
) )
subloras.append(lora) subloras.append(lora)
if module.__class__.__name__ == "FusedMoEWithLoRA":
lora = PackedLoRALayerWeights.pack_moe(subloras, module_name)
else:
lora = PackedLoRALayerWeights.pack(subloras) lora = PackedLoRALayerWeights.pack(subloras)
model.loras[module_name] = lora model.loras[module_name] = lora
return model return model
...@@ -739,6 +735,11 @@ class LoRAModelManager: ...@@ -739,6 +735,11 @@ class LoRAModelManager:
replaced_module_name = module_name.replace("model.", "") replaced_module_name = module_name.replace("model.", "")
if lora_model.check_lora_name(module_name): if lora_model.check_lora_name(module_name):
module_name = replaced_module_name module_name = replaced_module_name
if module_name.endswith(".experts"):
lora_model.loras[module_name] = PackedLoRALayerWeights.pack_moe(
replacement_loras, module_name
)
else:
lora_model.loras[module_name] = PackedLoRALayerWeights.pack( lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras replacement_loras
) )
...@@ -746,6 +747,9 @@ class LoRAModelManager: ...@@ -746,6 +747,9 @@ class LoRAModelManager:
for module in replaced_module: for module in replaced_module:
lora_model.loras.pop(module, None) lora_model.loras.pop(module, None)
for lora in lora_model.loras.values():
lora.optimize()
def _get_lora_layer_weights( def _get_lora_layer_weights(
self, lora_model: LoRAModel, module_name: str self, lora_model: LoRAModel, module_name: str
) -> LoRALayerWeights | None: ) -> LoRALayerWeights | None:
......
...@@ -170,16 +170,15 @@ def parse_fine_tuned_lora_name( ...@@ -170,16 +170,15 @@ def parse_fine_tuned_lora_name(
def is_base_embeddding_weights(name: str) -> bool: def is_base_embeddding_weights(name: str) -> bool:
# hardcoded suffixes for input & output embedding weights # hardcoded suffixes for input & output embedding weights
input_embedding_subfix = ".embed_tokens.base_layer.weight" embedding_suffixes = (
output_embedding_subfix = ".lm_head.base_layer.weight" ".embed_tokens.base_layer.weight",
".lm_head.base_layer.weight",
return name.endswith(input_embedding_subfix) or name.endswith(
output_embedding_subfix
) )
return name.endswith(embedding_suffixes)
def is_regex_target_modules( def is_regex_target_modules(
load_modules: str | list[str], expected_lora_modules: list[str] load_modules: str | list[str], expected_lora_modules: set[str]
) -> bool: ) -> bool:
""" """
PEFT supports passing `target_modules` in the form of regular expressions, PEFT supports passing `target_modules` in the form of regular expressions,
...@@ -195,8 +194,8 @@ def is_regex_target_modules( ...@@ -195,8 +194,8 @@ def is_regex_target_modules(
except re.error: except re.error:
return False return False
def is_subset(sub_list, full_list): def is_subset(sub_list, full_set):
return set(sub_list).issubset(set(full_list)) return set(sub_list).issubset(full_set)
# Similar to PEFT's processing logic, regex-related operations are only # Similar to PEFT's processing logic, regex-related operations are only
# executed when the load_modules is a `str`. # executed when the load_modules is a `str`.
...@@ -290,7 +289,7 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]: ...@@ -290,7 +289,7 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
# the expert indices are expanded based on the configured number # the expert indices are expanded based on the configured number
# of routed experts. # of routed experts.
packed_modules_mapping = get_packed_modules_mapping(model) packed_modules_mapping = get_packed_modules_mapping(model)
if not hasattr(model, "is_3d_moe_weight"): if not model.is_3d_moe_weight:
# 3D MoE LoRA does not need `packed_modules_mapping` # 3D MoE LoRA does not need `packed_modules_mapping`
packed_modules_mapping["experts"] = [ packed_modules_mapping["experts"] = [
weight_name.rstrip(".") weight_name.rstrip(".")
......
...@@ -88,15 +88,15 @@ class WorkerLoRAManager: ...@@ -88,15 +88,15 @@ class WorkerLoRAManager:
try: try:
supported_lora_modules = self._adapter_manager.supported_lora_modules supported_lora_modules = self._adapter_manager.supported_lora_modules
packed_modules_mapping = self._adapter_manager.packed_modules_mapping packed_modules_mapping = self._adapter_manager.packed_modules_mapping
expected_lora_modules: list[str] = [] expected_lora_lst: list[str] = []
for module in supported_lora_modules: for module in supported_lora_modules:
if module in packed_modules_mapping: if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module]) expected_lora_lst.extend(packed_modules_mapping[module])
else: else:
expected_lora_modules.append(module) expected_lora_lst.append(module)
if module == "experts": if module == "experts":
expected_lora_modules.append(module) expected_lora_lst.append(module)
expected_lora_modules = list(set(expected_lora_modules)) expected_lora_modules = set(expected_lora_lst)
lora_path = get_adapter_absolute_path(lora_request.lora_path) lora_path = get_adapter_absolute_path(lora_request.lora_path)
peft_helper = PEFTHelper.from_local_dir( peft_helper = PEFTHelper.from_local_dir(
......
...@@ -336,6 +336,7 @@ class SupportsLoRA(Protocol): ...@@ -336,6 +336,7 @@ class SupportsLoRA(Protocol):
There is no need to redefine this flag if this class is in the There is no need to redefine this flag if this class is in the
MRO of your model class. MRO of your model class.
""" """
is_3d_moe_weight: ClassVar[bool] = False
# The `embedding_module` and `embedding_padding_modules` # The `embedding_module` and `embedding_padding_modules`
# are empty by default. # are empty by default.
embedding_modules: ClassVar[dict[str, str]] = {} embedding_modules: ClassVar[dict[str, str]] = {}
......
...@@ -401,6 +401,7 @@ class Qwen3VLMoeMixtureOfExperts(MixtureOfExperts): ...@@ -401,6 +401,7 @@ class Qwen3VLMoeMixtureOfExperts(MixtureOfExperts):
class Qwen3VLMoeForConditionalGeneration( class Qwen3VLMoeForConditionalGeneration(
Qwen3VLForConditionalGeneration, Qwen3VLMoeMixtureOfExperts Qwen3VLForConditionalGeneration, Qwen3VLMoeMixtureOfExperts
): ):
is_3d_moe_weight: bool = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment