Unverified Commit bde38c11 authored by gnovack's avatar gnovack Committed by GitHub
Browse files

fix lora moe sharding when rank < max_lora_rank (#31994)


Signed-off-by: default avatargnovack <gnovack@amazon.com>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 707b240d
......@@ -95,7 +95,6 @@ def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
max_model_len=1024,
enable_lora=True,
max_loras=2,
max_lora_rank=8,
max_num_seqs=2,
max_num_batched_tokens=2048,
tensor_parallel_size=2,
......
......@@ -428,9 +428,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
current_lora_rank = w13_lora_a.shape[1]
assert current_lora_rank % self.tp_size == 0
# Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
sliced_rank = current_lora_rank // self.tp_size
start_idx = self.tp_rank * sliced_rank
end_idx = (self.tp_rank + 1) * sliced_rank
shard_size = self.w13_lora_a_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w13_lora_a[:, start_idx:end_idx, :]
def _slice_w13_b(self, w13_lora_b: torch.Tensor):
......@@ -465,11 +465,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
return w2_lora_b
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
# w2_lora_b shape (num_experts,output_size,rank)
current_lora_size = w2_lora_b.shape[1]
shard_size = self.w2_lora_b_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
sliced_size = current_lora_size // self.tp_size
start_idx = self.tp_rank * sliced_size
end_idx = (self.tp_rank + 1) * sliced_size
return w2_lora_b[:, start_idx:end_idx, :]
def reset_lora(self, index: int):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment