Unverified Commit 1073ba68 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[LoRA] Optimize 3D MoE logic (#29222)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent c309bb52
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import vllm import vllm
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -84,14 +86,17 @@ def test_gpt_oss_lora(gptoss20b_lora_files): ...@@ -84,14 +86,17 @@ def test_gpt_oss_lora(gptoss20b_lora_files):
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
def test_gpt_oss_lora_tp2(gptoss20b_lora_files): @pytest.mark.parametrize("fully_sharded_loras", [False, True])
def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
llm = vllm.LLM( llm = vllm.LLM(
MODEL_PATH, MODEL_PATH,
max_model_len=1024, max_model_len=1024,
enable_lora=True, enable_lora=True,
max_loras=2, max_loras=2,
max_lora_rank=8, max_lora_rank=8,
max_num_seqs=16,
tensor_parallel_size=2, tensor_parallel_size=2,
fully_sharded_loras=fully_sharded_loras,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False, cudagraph_specialize_lora=False,
), ),
......
...@@ -11,7 +11,7 @@ from vllm.lora.layers.column_parallel_linear import ( ...@@ -11,7 +11,7 @@ from vllm.lora.layers.column_parallel_linear import (
QKVParallelLinearWithLoRA, QKVParallelLinearWithLoRA,
QKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA,
) )
from vllm.lora.layers.fused_moe import FusedMoEWithLoRA from vllm.lora.layers.fused_moe import FusedMoE3DWithLoRA, FusedMoEWithLoRA
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.row_parallel_linear import ( from vllm.lora.layers.row_parallel_linear import (
...@@ -38,4 +38,5 @@ __all__ = [ ...@@ -38,4 +38,5 @@ __all__ = [
"ReplicatedLinearWithLoRA", "ReplicatedLinearWithLoRA",
"LoRAMapping", "LoRAMapping",
"FusedMoEWithLoRA", "FusedMoEWithLoRA",
"FusedMoE3DWithLoRA",
] ]
...@@ -42,8 +42,8 @@ class BaseLayerWithLoRA(nn.Module): ...@@ -42,8 +42,8 @@ class BaseLayerWithLoRA(nn.Module):
def set_lora( def set_lora(
self, self,
index: int, index: int,
lora_a: torch.Tensor, lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor, lora_b: torch.Tensor | list[torch.Tensor],
): ):
"""Overwrites lora tensors at index.""" """Overwrites lora tensors at index."""
... ...
......
...@@ -94,13 +94,15 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): ...@@ -94,13 +94,15 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
def set_lora( def set_lora(
self, self,
index: int, index: int,
lora_a: torch.Tensor, lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor, lora_b: torch.Tensor | list[torch.Tensor],
): ):
# Except for QKVParallelLinearWithLoRA and # Except for QKVParallelLinearWithLoRA and
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
# store weights in a tuple of size 1. These two layers will # store weights in a tuple of size 1. These two layers will
# override this function. # override this function.
assert isinstance(lora_a, torch.Tensor)
assert isinstance(lora_b, torch.Tensor)
assert ( assert (
len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1 len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1
) )
......
...@@ -246,8 +246,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -246,8 +246,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def set_lora( def set_lora(
self, self,
index: int, index: int,
lora_a: torch.Tensor, lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor, lora_b: torch.Tensor | list[torch.Tensor],
): ):
self.reset_lora(index) self.reset_lora(index)
......
...@@ -42,7 +42,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -42,7 +42,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.device = base_layer.w2_weight.device self.device = base_layer.w2_weight.device
self.w13_slices = 2 self._w13_slices = 2
self._inject_lora_into_fused_moe() self._inject_lora_into_fused_moe()
def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]: def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
...@@ -160,7 +160,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -160,7 +160,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
op_prefix="w13", op_prefix="w13",
num_loras=self.max_loras, num_loras=self.max_loras,
rank=max_lora_rank, rank=max_lora_rank,
num_slices=self.w13_slices, num_slices=self._w13_slices,
M=M, M=M,
layer=layer, layer=layer,
top_k=top_k, top_k=top_k,
...@@ -230,7 +230,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -230,7 +230,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE) M = min(num_tokens, CHUNK_SIZE)
max_lora_rank = self.w2_lora_a_stacked.shape[-2] max_lora_rank = self.w2_lora_a_stacked[0].shape[-2]
shrink_config, expand_config = self._get_lora_moe_configs( shrink_config, expand_config = self._get_lora_moe_configs(
op_prefix="w2", op_prefix="w2",
num_loras=self.max_loras, num_loras=self.max_loras,
...@@ -258,8 +258,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -258,8 +258,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.punica_wrapper.add_lora_fused_moe( self.punica_wrapper.add_lora_fused_moe(
intermediate_cache3, intermediate_cache3,
intermediate_cache2, intermediate_cache2,
(self.w2_lora_a_stacked,), self.w2_lora_a_stacked,
(self.w2_lora_b_stacked,), self.w2_lora_b_stacked,
topk_weights, topk_weights,
sorted_token_ids_lora, sorted_token_ids_lora,
expert_ids_lora, expert_ids_lora,
...@@ -292,22 +292,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -292,22 +292,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.base_layer.quant_method, m_fused_moe_fn self.base_layer.quant_method, m_fused_moe_fn
) )
def create_lora_weights( def _create_lora_a_weights(
self, self,
max_loras: int, max_loras: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None, ):
) -> None: self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
"""Initializes lora matrices."""
assert self.w13_slices == 2
self.max_loras = lora_config.max_loras
self.fully_sharded = lora_config.fully_sharded_loras
self.adapter_enabled = torch.tensor(
[0] * (max_loras + 1), dtype=torch.int, device=self.device
)
self.w13_lora_a_stacked = tuple(
torch.zeros( torch.zeros(
( (
max_loras, max_loras,
...@@ -320,34 +310,37 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -320,34 +310,37 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
) )
for _ in range(self.w13_slices) for _ in range(self._w13_slices)
) )
self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
self.w13_lora_b_stacked = tuple(
torch.zeros( torch.zeros(
( (
max_loras, max_loras,
self.base_layer.local_num_experts, self.base_layer.local_num_experts,
self.base_layer.intermediate_size_per_partition,
lora_config.max_lora_rank, lora_config.max_lora_rank,
self.base_layer.intermediate_size_per_partition,
), ),
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
) ),
for _ in range(self.w13_slices)
) )
self.w2_lora_a_stacked = torch.zeros( def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
torch.zeros(
( (
max_loras, max_loras,
self.base_layer.local_num_experts, self.base_layer.local_num_experts,
lora_config.max_lora_rank,
self.base_layer.intermediate_size_per_partition, self.base_layer.intermediate_size_per_partition,
lora_config.max_lora_rank,
), ),
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
) )
self.w2_lora_b_stacked = torch.zeros( for _ in range(self._w13_slices)
)
self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
torch.zeros(
( (
max_loras, max_loras,
self.base_layer.local_num_experts, self.base_layer.local_num_experts,
...@@ -358,10 +351,28 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -358,10 +351,28 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
), ),
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
),
)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
self.max_loras = lora_config.max_loras
self.fully_sharded = lora_config.fully_sharded_loras
self.adapter_enabled = torch.tensor(
[0] * (max_loras + 1), dtype=torch.int, device=self.device
) )
self._create_lora_a_weights(max_loras, lora_config)
self._create_lora_b_weights(max_loras, lora_config)
# They will be used by 'LoRALayerWeights.create_dummy_lora_weights' # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
# to create a dummy LoRA weights. # to create a dummy LoRA weights.
# TODO Optimize this section
self.lora_a_stacked = [] self.lora_a_stacked = []
self.lora_b_stacked = [] self.lora_b_stacked = []
for lora_id in range(max_loras): for lora_id in range(max_loras):
...@@ -370,36 +381,43 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -370,36 +381,43 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.lora_a_stacked.append( self.lora_a_stacked.append(
self.w13_lora_a_stacked[0][lora_id][experts_id] self.w13_lora_a_stacked[0][lora_id][experts_id]
) )
self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
self.lora_a_stacked.append( self.lora_a_stacked.append(
self.w13_lora_a_stacked[1][lora_id][experts_id] self.w2_lora_a_stacked[0][lora_id][experts_id]
) )
self.lora_b_stacked.append( self.lora_b_stacked.append(
self.w13_lora_b_stacked[0][lora_id][experts_id] self.w13_lora_b_stacked[0][lora_id][experts_id]
) )
self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id]) self.lora_b_stacked.append(
self.w2_lora_b_stacked[0][lora_id][experts_id]
)
self.lora_a_stacked.append(
self.w13_lora_a_stacked[1][lora_id][experts_id]
)
self.lora_b_stacked.append( self.lora_b_stacked.append(
self.w13_lora_b_stacked[1][lora_id][experts_id] self.w13_lora_b_stacked[1][lora_id][experts_id]
) )
def reset_lora(self, index: int): def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0.""" """Resets the lora weights at index back to 0."""
for pos in range(self.w13_slices): for pos in range(self._w13_slices):
self.w13_lora_a_stacked[pos][index] = 0 self.w13_lora_a_stacked[pos][index] = 0
self.w13_lora_b_stacked[pos][index] = 0 self.w13_lora_b_stacked[pos][index] = 0
self.w2_lora_a_stacked[index] = 0 self.w2_lora_a_stacked[0][index] = 0
self.w2_lora_b_stacked[index] = 0 self.w2_lora_b_stacked[0][index] = 0
self.adapter_enabled[index] = 0 self.adapter_enabled[index] = 0
def set_lora( def set_lora(
self, self,
index: int, index: int,
lora_a: torch.Tensor, lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor, lora_b: torch.Tensor | list[torch.Tensor],
): ):
"""Overwrites lora tensors at index.""" """Overwrites lora tensors at index."""
assert isinstance(lora_a, list)
assert isinstance(lora_b, list)
self.reset_lora(index) self.reset_lora(index)
self.adapter_enabled[index] = 1 self.adapter_enabled[index] = 1
for eid in range(len(lora_a) // 3): for eid in range(len(lora_a) // 3):
...@@ -432,7 +450,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -432,7 +450,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :] w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :] w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]
w2_shard_size = self.w2_lora_b_stacked[index, eid].shape[0] w2_shard_size = self.w2_lora_b_stacked[0][index, eid].shape[0]
w2_start_idx = self.tp_rank * w2_shard_size w2_start_idx = self.tp_rank * w2_shard_size
w2_end_idx = (self.tp_rank + 1) * w2_shard_size w2_end_idx = (self.tp_rank + 1) * w2_shard_size
w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :] w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
...@@ -454,14 +472,32 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -454,14 +472,32 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1] index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
].copy_(w3_lora_b, non_blocking=True) ].copy_(w3_lora_b, non_blocking=True)
self.w2_lora_a_stacked[ self.w2_lora_a_stacked[0][
index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1] index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
].copy_(w2_lora_a, non_blocking=True) ].copy_(w2_lora_a, non_blocking=True)
self.w2_lora_b_stacked[ self.w2_lora_b_stacked[0][
index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1] index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
].copy_(w2_lora_b, non_blocking=True) ].copy_(w2_lora_b, non_blocking=True)
def forward(self, *args, **kwargs):
return self.base_layer.forward(*args, **kwargs)
def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)
@property
def _shared_experts(self):
return self.base_layer._shared_experts
@property
def quant_method(self):
return self.base_layer.quant_method
@property
def is_internal_router(self) -> bool:
return self.base_layer.is_internal_router
@classmethod @classmethod
def can_replace_layer( def can_replace_layer(
cls, cls,
...@@ -472,22 +508,209 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -472,22 +508,209 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
) -> bool: ) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer.""" """Returns True if the layer can be replaced by this LoRA layer."""
# return type(source_layer) is FusedMoE # return type(source_layer) is FusedMoE
return isinstance(source_layer, FusedMoE)
def forward(self, *args, **kwargs): return type(source_layer) is FusedMoE and len(packed_modules_list) == 2
return self.base_layer.forward(*args, **kwargs)
def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs) class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
def __init__(self, base_layer):
super().__init__(base_layer)
self._w13_slices = 1
def _create_lora_b_weights(self, max_loras, lora_config):
self.w13_lora_b_stacked: tuple[torch.Tensor] = tuple(
torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.intermediate_size_per_partition * 2,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self._w13_slices)
)
self.w2_lora_b_stacked: tuple[torch.Tensor] = (
torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.hidden_size
if not self.fully_sharded
else divide(self.base_layer.hidden_size, self.tp_size),
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
),
)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
self.max_loras = lora_config.max_loras
self.fully_sharded = lora_config.fully_sharded_loras
self.adapter_enabled = torch.tensor(
[0] * (max_loras + 1), dtype=torch.int, device=self.device
)
self._create_lora_a_weights(max_loras, lora_config)
self._create_lora_b_weights(max_loras, lora_config)
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
if self.tp_size == 1 or not self.fully_sharded:
return w13_lora_a
# w13_lora_a shape (num_experts,rank,input_size)
current_lora_rank = w13_lora_a.shape[1]
assert current_lora_rank % self.tp_size == 0
sliced_rank = current_lora_rank // self.tp_size
start_idx = self.tp_rank * sliced_rank
end_idx = (self.tp_rank + 1) * sliced_rank
return w13_lora_a[:, start_idx:end_idx, :]
def _slice_w13_b(self, w13_lora_b: torch.Tensor, is_interleave: bool = True):
if self.tp_size == 1:
return w13_lora_b
# w13_lora_b shape (num_experts,output_size,rank)
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
if is_interleave:
# For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
# in the interleaved order, and corresponding LoRA need to be processed.
w1_lora_b = w13_lora_b[:, ::2, :]
w3_lora_b = w13_lora_b[:, 1::2, :]
sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]
return torch.stack([sliced_w1_lora_b, sliced_w3_lora_b], dim=2).flatten(
1, 2
)
else:
slice_size = w13_lora_b.shape[1] // 2
w1_lora_b = w13_lora_b[:, :slice_size, :]
w3_lora_b = w13_lora_b[:, slice_size:, :]
sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]
return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
if self.tp_size == 1:
return w2_lora_a
# w2_lora_a shape (num_experts,rank,input_size)
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w2_lora_a[:, :, start_idx:end_idx]
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
if self.tp_size == 1 or not self.fully_sharded:
return w2_lora_b
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
# w2_lora_b shape (num_experts,output_size,rank)
current_lora_size = w2_lora_b.shape[1]
sliced_size = current_lora_size // self.tp_size
start_idx = self.tp_rank * sliced_size
end_idx = (self.tp_rank + 1) * sliced_size
return w2_lora_b[:, start_idx:end_idx, :]
def set_lora(
self,
index: int,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
"""Overwrites lora tensors at index."""
# Make mypy happy
assert isinstance(lora_a, list)
assert isinstance(lora_b, list)
assert len(lora_a) == len(lora_b) == 2
self.reset_lora(index)
self.adapter_enabled[index] = 1
num_experts = self.w13_lora_a_stacked[0].shape[1]
w13_lora_a, w2_lora_a = lora_a
w13_lora_b, w2_lora_b = lora_b
# (num_experts,rank,input_size)
w13_lora_a = w13_lora_a.reshape(num_experts, -1, w13_lora_a.shape[-1])
w2_lora_a = w2_lora_a.reshape(num_experts, -1, w2_lora_a.shape[-1])
# (output_size,num_experts,rank)
w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], num_experts, -1)
w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], num_experts, -1)
# (num_experts,output_size,rank)
w13_lora_b = w13_lora_b.permute(1, 0, 2)
w2_lora_b = w2_lora_b.permute(1, 0, 2)
sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
sliced_w13_lora_b = self._slice_w13_b(w13_lora_b, is_interleave=True)
sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
self.w13_lora_a_stacked[0][
index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
].copy_(sliced_w13_lora_a, non_blocking=True)
self.w2_lora_a_stacked[0][
index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
].copy_(sliced_w2_lora_a, non_blocking=True)
self.w13_lora_b_stacked[0][
index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
].copy_(sliced_w13_lora_b, non_blocking=True)
self.w2_lora_b_stacked[0][
index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
].copy_(sliced_w2_lora_b, non_blocking=True)
@property @property
def _shared_experts(self): def w13_input_size(self):
return self.base_layer._shared_experts """
Full size
"""
return self.w13_lora_a_stacked[0].shape[-1]
@property @property
def quant_method(self): def w13_output_size(self):
return self.base_layer.quant_method """
Full size
"""
return self.w13_lora_b_stacked[0].shape[-2] * self.tp_size
@property @property
def is_internal_router(self) -> bool: def w2_input_size(self):
return self.base_layer.is_internal_router """
Full size
"""
return self.w2_lora_a_stacked[0].shape[-1] * self.tp_size
@property
def w2_output_size(self):
"""
Full size
"""
return self.w2_lora_a_stacked[0].shape[-2]
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
return type(source_layer) is FusedMoE and len(packed_modules_list) == 1
...@@ -128,9 +128,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -128,9 +128,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def set_lora( def set_lora(
self, self,
index: int, index: int,
lora_a: torch.Tensor, lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor, lora_b: torch.Tensor | list[torch.Tensor],
): ):
assert isinstance(lora_a, torch.Tensor)
assert isinstance(lora_b, torch.Tensor)
self.reset_lora(index) self.reset_lora(index)
self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_( self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
lora_a, non_blocking=True lora_a, non_blocking=True
......
...@@ -77,12 +77,15 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -77,12 +77,15 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def set_lora( def set_lora(
self, self,
index: int, index: int,
lora_a: torch.Tensor, lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor, lora_b: torch.Tensor | list[torch.Tensor],
): ):
assert isinstance(lora_a, torch.Tensor)
assert isinstance(lora_b, torch.Tensor)
self.reset_lora(index) self.reset_lora(index)
# NOTE self.lora_a_stacked is row-major, and lora_a is col-major, # NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
# so we need transpose here # so we need transpose here
self.lora_a_stacked[index, : lora_a.shape[1], : lora_a.shape[0]].copy_( self.lora_a_stacked[index, : lora_a.shape[1], : lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True lora_a.T, non_blocking=True
) )
......
...@@ -22,11 +22,13 @@ from vllm.lora.utils import ( ...@@ -22,11 +22,13 @@ from vllm.lora.utils import (
from_layer_logits_processor, from_layer_logits_processor,
get_supported_lora_modules, get_supported_lora_modules,
is_base_embeddding_weights, is_base_embeddding_weights,
is_moe_model,
is_regex_target_modules, is_regex_target_modules,
parse_fine_tuned_lora_name, parse_fine_tuned_lora_name,
process_packed_modules_mapping, process_packed_modules_mapping,
replace_submodule, replace_submodule,
) )
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model from vllm.model_executor.models.interfaces import is_pooling_model
...@@ -356,7 +358,11 @@ class LoRAModelManager: ...@@ -356,7 +358,11 @@ class LoRAModelManager:
self.modules: dict[str, BaseLayerWithLoRA] = {} self.modules: dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a set for compatibility with LRUCache. # Dict instead of a set for compatibility with LRUCache.
self._last_mapping: LoRAMapping | None = None self._last_mapping: LoRAMapping | None = None
self._is_3d_moe_model = is_moe_model(self.model) and hasattr(
self.model, "is_3d_moe_weight"
)
self._create_lora_modules() self._create_lora_modules()
self.model.lora_manager = self self.model.lora_manager = self
def __len__(self) -> int: def __len__(self) -> int:
...@@ -400,22 +406,36 @@ class LoRAModelManager: ...@@ -400,22 +406,36 @@ class LoRAModelManager:
self.lora_index_to_id[index] = lora_model.id self.lora_index_to_id[index] = lora_model.id
for module_name, module in self.modules.items(): for module_name, module in self.modules.items():
module_lora = self._get_lora_layer_weights(lora_model, module_name) module_lora = self._get_lora_layer_weights(lora_model, module_name)
if module_lora: if not module_lora:
module.reset_lora(index)
continue
# Note (gnovack) - If MOE lora weights are not split into # Note (gnovack) - If MOE lora weights are not split into
# num_experts chunks, we split them here # num_experts chunks, we split them here
if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor( if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
module_lora.lora_a module_lora.lora_a
): ):
# Handle FSDP file format where experts.base_layer is the # Handle PEFT file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj # gate_up_proj and experts is the down_proj
gate_up_proj_lora = self._get_lora_layer_weights( gate_up_proj_lora = self._get_lora_layer_weights(
lora_model, module_name + ".base_layer" lora_model, module_name + ".base_layer"
) )
assert gate_up_proj_lora is not None
assert module_lora is not None
down_proj_lora = module_lora down_proj_lora = module_lora
# FIXME Edge case where LoRA is not added to gate_up_proj
# or down_proj
assert gate_up_proj_lora is not None
assert down_proj_lora is not None
if self._is_3d_moe_model:
module_lora.lora_a = [
gate_up_proj_lora.lora_a,
down_proj_lora.lora_a,
]
module_lora.lora_b = [
gate_up_proj_lora.lora_b,
down_proj_lora.lora_b,
]
else:
# Some 3D MoE models haven't added the `is_3d_moe_weight`
# attribute yet, so fallback here
num_experts = module_lora.lora_a.shape[0] // module_lora.rank num_experts = module_lora.lora_a.shape[0] // module_lora.rank
gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0) gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
...@@ -444,14 +464,12 @@ class LoRAModelManager: ...@@ -444,14 +464,12 @@ class LoRAModelManager:
module_lora.lora_a = lora_a module_lora.lora_a = lora_a
module_lora.lora_b = lora_b module_lora.lora_b = lora_b
module.set_lora( module.set_lora(
index, index,
module_lora.lora_a, module_lora.lora_a,
module_lora.lora_b, module_lora.lora_b,
) )
else:
module.reset_lora(index)
return True return True
def _deactivate_adapter(self, lora_id: int): def _deactivate_adapter(self, lora_id: int):
...@@ -512,6 +530,13 @@ class LoRAModelManager: ...@@ -512,6 +530,13 @@ class LoRAModelManager:
continue continue
parts = module_name.split(".")[-1] parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, []) packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
if isinstance(module, FusedMoE):
# packed_moduled_lst is used here to just determine whether to
# instantiate FusedMoE3DWithLoRA or FusedMoEWithLoRA, and the
# difference between these two LoRA layers is whether the
# LoRA weights of w1 and w3 have already been fused on disk.
packed_moduled_lst = ["w13"] if self._is_3d_moe_model else ["w1", "w3"]
new_module = replace_submodule( new_module = replace_submodule(
self.model, self.model,
module_name, module_name,
...@@ -560,6 +585,7 @@ class LoRAModelManager: ...@@ -560,6 +585,7 @@ class LoRAModelManager:
self._register_packed_modules(module_name) self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference. # All lora layers share the same punica_wrapper based on reference.
new_module.set_mapping(self.punica_wrapper) new_module.set_mapping(self.punica_wrapper)
pass
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA), ( assert isinstance(module, BaseLayerWithLoRA), (
...@@ -605,6 +631,30 @@ class LoRAModelManager: ...@@ -605,6 +631,30 @@ class LoRAModelManager:
module.lora_a_stacked[0].dtype, module.lora_a_stacked[0].dtype,
"cpu", "cpu",
) )
model.loras[module_name] = lora
elif module.__class__.__name__ == "FusedMoE3DWithLoRA":
# Case for 3D moe model
# w2
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
module.w2_input_size,
module.w2_output_size,
rank * module.w2_lora_a_stacked[0].shape[1], # rank*num_experts
module.w2_lora_a_stacked[0].dtype,
"cpu",
)
model.loras[module_name] = lora
# w13
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
module.w13_input_size,
module.w13_output_size,
rank
* module.w13_lora_a_stacked[0].shape[1], # rank*num_experts
module.w13_lora_a_stacked[0].dtype,
"cpu",
)
model.loras[module_name + ".base_layer"] = lora
else: else:
lora = LoRALayerWeights.create_dummy_lora_weights( lora = LoRALayerWeights.create_dummy_lora_weights(
module_name, module_name,
...@@ -614,6 +664,7 @@ class LoRAModelManager: ...@@ -614,6 +664,7 @@ class LoRAModelManager:
module.lora_a_stacked[0].dtype, module.lora_a_stacked[0].dtype,
"cpu", "cpu",
) )
model.loras[module_name] = lora
else: else:
parts = module_name.split(".") parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]] replacements = self.packed_modules_mapping[parts[-1]]
......
...@@ -23,6 +23,7 @@ from vllm.lora.layers import ( ...@@ -23,6 +23,7 @@ from vllm.lora.layers import (
BaseLayerWithLoRA, BaseLayerWithLoRA,
ColumnParallelLinearWithLoRA, ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA, ColumnParallelLinearWithShardedLoRA,
FusedMoE3DWithLoRA,
FusedMoEWithLoRA, FusedMoEWithLoRA,
LogitsProcessorWithLoRA, LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
...@@ -62,6 +63,7 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = { ...@@ -62,6 +63,7 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
MergedQKVParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA,
FusedMoEWithLoRA, FusedMoEWithLoRA,
FusedMoE3DWithLoRA,
} }
...@@ -288,9 +290,11 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]: ...@@ -288,9 +290,11 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
# the expert indices are expanded based on the configured number # the expert indices are expanded based on the configured number
# of routed experts. # of routed experts.
packed_modules_mapping = get_packed_modules_mapping(model) packed_modules_mapping = get_packed_modules_mapping(model)
if not hasattr(model, "is_3d_moe_weight"):
# 3D MoE LoRA does not need `packed_modules_mapping`
packed_modules_mapping["experts"] = [ packed_modules_mapping["experts"] = [
weight_name.rstrip(".") for _, weight_name, _, _ in moe_packed_mapping weight_name.rstrip(".")
for _, weight_name, _, _ in moe_packed_mapping
] ]
return packed_modules_mapping return packed_modules_mapping
......
...@@ -656,6 +656,7 @@ class GptOssModel(nn.Module): ...@@ -656,6 +656,7 @@ class GptOssModel(nn.Module):
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA): class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
is_3d_moe_weight: bool = True
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment