Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7320ca39
Unverified
Commit
7320ca39
authored
Feb 01, 2026
by
Runkai Tao
Committed by
GitHub
Feb 02, 2026
Browse files
Add unpermute-aware fused MoE LoRA path (#32655)
Signed-off-by:
Runkai Tao
<
rt572@physics.rutgers.edu
>
parent
cf0a99f8
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
474 additions
and
119 deletions
+474
-119
benchmarks/kernels/benchmark_lora.py
benchmarks/kernels/benchmark_lora.py
+2
-0
tests/lora/test_fused_moe_lora_kernel.py
tests/lora/test_fused_moe_lora_kernel.py
+184
-0
vllm/lora/layers/fused_moe.py
vllm/lora/layers/fused_moe.py
+25
-4
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+201
-70
vllm/lora/punica_wrapper/punica_base.py
vllm/lora/punica_wrapper/punica_base.py
+4
-3
vllm/lora/punica_wrapper/punica_gpu.py
vllm/lora/punica_wrapper/punica_gpu.py
+58
-42
No files found.
benchmarks/kernels/benchmark_lora.py
View file @
7320ca39
...
@@ -842,6 +842,7 @@ class BenchmarkTensors:
...
@@ -842,6 +842,7 @@ class BenchmarkTensors:
"sorted_token_ids"
:
sorted_token_ids
,
"sorted_token_ids"
:
sorted_token_ids
,
"expert_ids"
:
expert_ids
,
"expert_ids"
:
expert_ids
,
"num_tokens_post_padded"
:
num_tokens_post_padded
,
"num_tokens_post_padded"
:
num_tokens_post_padded
,
"token_lora_mapping"
:
self
.
lora_kernel_meta
.
token_lora_mapping
,
"top_k_num"
:
ctx
.
top_k_num
,
"top_k_num"
:
ctx
.
top_k_num
,
"device"
:
self
.
input
.
device
,
"device"
:
self
.
input
.
device
,
"N"
:
lora_rank
,
"N"
:
lora_rank
,
...
@@ -915,6 +916,7 @@ class BenchmarkTensors:
...
@@ -915,6 +916,7 @@ class BenchmarkTensors:
"sorted_token_ids"
:
sorted_token_ids
,
"sorted_token_ids"
:
sorted_token_ids
,
"expert_ids"
:
expert_ids
,
"expert_ids"
:
expert_ids
,
"num_tokens_post_padded"
:
num_tokens_post_padded
,
"num_tokens_post_padded"
:
num_tokens_post_padded
,
"token_lora_mapping"
:
self
.
lora_kernel_meta
.
token_lora_mapping
,
"top_k_num"
:
ctx
.
top_k_num
,
"top_k_num"
:
ctx
.
top_k_num
,
"device"
:
self
.
input
.
device
,
"device"
:
self
.
input
.
device
,
"N"
:
lora_rank
,
"N"
:
lora_rank
,
...
...
tests/lora/test_fused_moe_lora_kernel.py
View file @
7320ca39
...
@@ -190,6 +190,7 @@ def use_fused_moe_lora_kernel(
...
@@ -190,6 +190,7 @@ def use_fused_moe_lora_kernel(
sorted_token_ids
,
sorted_token_ids
,
expert_ids
,
expert_ids
,
num_tokens_post_padded
,
num_tokens_post_padded
,
token_lora_mapping
,
max_lora_rank
,
max_lora_rank
,
top_k_num
,
top_k_num
,
lora_ids
,
lora_ids
,
...
@@ -333,6 +334,189 @@ def test_fused_moe_lora_kernel(
...
@@ -333,6 +334,189 @@ def test_fused_moe_lora_kernel(
torch
.
testing
.
assert_close
(
output
,
output2
,
atol
=
1e-1
,
rtol
=
1e-1
)
torch
.
testing
.
assert_close
(
output
,
output2
,
atol
=
1e-1
,
rtol
=
1e-1
)
def use_fused_moe_lora_kernel_naive(
    topk_ids,
    topk_weights,
    token_lora_mapping,
    max_lora_rank,
    top_k_num,
    lora_a_stacked,
    lora_b_stacked,
    hidden_states,
    output,
    max_loras,
    block_size,
    fully_sharded=False,
    offset=0,
):
    """
    Drive fused_moe_lora through its naive_block_assignment path.

    moe_lora_align_block_size is not called here. Instead the flattened
    topk_ids are handed over as expert_ids, and both sorted_token_ids and
    num_tokens_post_padded are passed as None.

    The result is accumulated into ``output`` by fused_moe_lora; nothing is
    returned.
    """
    config = {
        "BLOCK_SIZE_M": block_size,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "NUM_WARPS": 4,
        "NUM_STAGES": 3,
        "SPLIT_K": 1,
    }
    # Positional order in which fused_moe_lora receives one launch config.
    launch_keys = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "NUM_WARPS",
        "NUM_STAGES",
        "SPLIT_K",
    )
    launch_args = tuple(config[key] for key in launch_keys)

    # One entry per LoRA slot plus one extra; every adapter is enabled.
    adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
    lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)

    # Naive-mode inputs:
    #   expert_ids             -> topk_ids flattened to (num_tokens * top_k,)
    #   sorted_token_ids       -> None
    #   num_tokens_post_padded -> None
    fused_moe_lora(
        output,
        hidden_states,
        lora_a_stacked,
        lora_b_stacked,
        topk_weights,
        None,
        topk_ids.reshape(-1),
        None,
        token_lora_mapping,
        max_lora_rank,
        top_k_num,
        lora_ids,
        adapter_enabled,
        # The same launch config is supplied twice -- presumably once for
        # the shrink kernel and once for the expand kernel (TODO confirm
        # against the fused_moe_lora signature).
        *launch_args,
        *launch_args,
        mul_routed_weight=False,
        fully_sharded=fully_sharded,
        offset=offset,
    )
@pytest.mark.parametrize("num_tokens", [1, 2, 4, 8])
@pytest.mark.parametrize("top_k_num", [1, 2])
@pytest.mark.parametrize("num_experts", [64, 128])
@pytest.mark.parametrize("max_loras", [4, 8])
@pytest.mark.parametrize("N", [1408])
@pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_fused_moe_lora_kernel_naive_block_assignment(
    num_tokens,
    top_k_num,
    num_experts,
    max_loras,
    N,
    K,
    max_lora_rank,
    block_size,
    dtype,
    device,
    seed,
):
    """
    Check the naive_block_assignment path of fused_moe_lora against the
    PyTorch reference (use_torch).

    The naive path applies when num_tokens * top_k is much smaller than
    num_experts * max_loras; it bypasses the moe_lora_align_block_size
    kernel.
    """
    torch.set_default_device(device)
    set_random_seed(seed)

    # Guard: every parametrized combination must satisfy the sparsity
    # condition num_tokens * top_k * SPARSITY_FACTOR <= num_experts * max_loras,
    # otherwise the test would not be exercising the naive path.
    SPARSITY_FACTOR = 8
    sparsity_lhs = num_tokens * top_k_num * SPARSITY_FACTOR
    sparsity_rhs = num_experts * max_loras
    assert sparsity_lhs <= sparsity_rhs, (
        f"Test configuration doesn't meet naive_block_assignment condition: "
        f"{num_tokens} * {top_k_num} * {SPARSITY_FACTOR} > {num_experts} * {max_loras}"
    )

    # Number of randomly generated sequences, capped at 4.
    num_sequences = min(num_tokens, 4)

    # Routing data and the per-token LoRA assignment.
    topk_ids, topk_weights, token_lora_mapping = sample_data(
        num_tokens, num_sequences, max_loras, num_experts, top_k_num
    )

    # Random LoRA A/B weights and activations. The torch.rand calls stay in
    # this order so the seeded random stream is consumed as before.
    lora_a_shape = (max_loras, num_experts, max_lora_rank, K)
    lora_b_shape = (max_loras, num_experts, N, max_lora_rank)
    lora_a_stacked = [torch.rand(lora_a_shape, dtype=dtype)]
    lora_b_stacked = [torch.rand(lora_b_shape, dtype=dtype)]
    hidden_states = torch.rand((num_tokens, K), dtype=dtype)

    # Kernel result (naive path), accumulated into a zero-initialised buffer.
    output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype)
    use_fused_moe_lora_kernel_naive(
        topk_ids,
        topk_weights,
        token_lora_mapping,
        max_lora_rank,
        top_k_num,
        lora_a_stacked,
        lora_b_stacked,
        hidden_states,
        output,
        max_loras,
        block_size,
    )

    # PyTorch reference result.
    output_ref = use_torch(
        hidden_states,
        token_lora_mapping,
        topk_ids,
        lora_a_stacked,
        lora_b_stacked,
        top_k_num,
    )

    torch.testing.assert_close(output, output_ref, atol=1e-1, rtol=1e-1)
@
multi_gpu_test
(
num_gpus
=
2
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
100
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
100
])
@
pytest
.
mark
.
parametrize
(
"top_k_num"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"top_k_num"
,
[
6
])
...
...
vllm/lora/layers/fused_moe.py
View file @
7320ca39
...
@@ -190,8 +190,18 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
...
@@ -190,8 +190,18 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
config_dtype
=
config_dtype
,
config_dtype
=
config_dtype
,
)
)
# SPARSITY_FACTOR is a heuristic margin ensuring tokens * top_k
# activates only a small fraction of total experts * loras.
SPARSITY_FACTOR
=
8
naive_block_assignment
=
(
expert_map
is
None
and
num_tokens
*
top_k
*
SPARSITY_FACTOR
<=
self
.
base_layer
.
local_num_experts
*
self
.
max_loras
)
# get the block size of m from customized config or default config
# get the block size of m from customized config or default config
(
(
token_lora_mapping
,
sorted_token_ids_lora
,
sorted_token_ids_lora
,
expert_ids_lora
,
expert_ids_lora
,
num_tokens_post_padded_lora
,
num_tokens_post_padded_lora
,
...
@@ -203,6 +213,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
...
@@ -203,6 +213,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self
.
max_loras
,
self
.
max_loras
,
self
.
adapter_enabled
,
self
.
adapter_enabled
,
expert_map
,
expert_map
,
naive_block_assignment
,
)
)
moe_state_dict
[
"sorted_token_ids_lora"
]
=
sorted_token_ids_lora
moe_state_dict
[
"sorted_token_ids_lora"
]
=
sorted_token_ids_lora
...
@@ -210,9 +221,13 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
...
@@ -210,9 +221,13 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
moe_state_dict
[
"num_tokens_post_padded_lora"
]
=
(
moe_state_dict
[
"num_tokens_post_padded_lora"
]
=
(
num_tokens_post_padded_lora
num_tokens_post_padded_lora
)
)
moe_state_dict
[
"token_lora_mapping"
]
=
token_lora_mapping
if
sorted_token_ids_lora
is
not
None
:
expert_ids_lora
=
expert_ids_lora
.
view
(
self
.
max_loras
,
-
1
)
expert_ids_lora
=
expert_ids_lora
.
view
(
self
.
max_loras
,
-
1
)
sorted_token_ids_lora
=
sorted_token_ids_lora
.
view
(
self
.
max_loras
,
-
1
)
sorted_token_ids_lora
=
sorted_token_ids_lora
.
view
(
self
.
max_loras
,
-
1
)
#
#
self
.
punica_wrapper
.
add_lora_fused_moe
(
self
.
punica_wrapper
.
add_lora_fused_moe
(
...
@@ -230,6 +245,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
...
@@ -230,6 +245,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
expand_config
,
## pass the expand config
expand_config
,
## pass the expand config
self
.
adapter_enabled
,
self
.
adapter_enabled
,
fully_sharded
=
self
.
fully_sharded
,
fully_sharded
=
self
.
fully_sharded
,
token_lora_mapping
=
token_lora_mapping
,
)
)
result
=
func
(
*
args
,
**
kwargs
)
result
=
func
(
*
args
,
**
kwargs
)
...
@@ -270,9 +286,13 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
...
@@ -270,9 +286,13 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
num_tokens_post_padded_lora
=
moe_state_dict
[
num_tokens_post_padded_lora
=
moe_state_dict
[
"num_tokens_post_padded_lora"
"num_tokens_post_padded_lora"
]
]
token_lora_mapping
=
moe_state_dict
.
get
(
"token_lora_mapping"
)
if
sorted_token_ids_lora
is
not
None
:
expert_ids_lora
=
expert_ids_lora
.
view
(
self
.
max_loras
,
-
1
)
expert_ids_lora
=
expert_ids_lora
.
view
(
self
.
max_loras
,
-
1
)
sorted_token_ids_lora
=
sorted_token_ids_lora
.
view
(
self
.
max_loras
,
-
1
)
sorted_token_ids_lora
=
sorted_token_ids_lora
.
view
(
self
.
max_loras
,
-
1
)
intermediate_cache2
=
moe_state_dict
[
"intermediate_cache2"
]
intermediate_cache2
=
moe_state_dict
[
"intermediate_cache2"
]
intermediate_cache3
=
args
[
0
]
intermediate_cache3
=
args
[
0
]
...
@@ -295,6 +315,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
...
@@ -295,6 +315,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
True
,
True
,
fully_sharded
=
self
.
fully_sharded
,
fully_sharded
=
self
.
fully_sharded
,
offset
=
shard_size_w2
*
self
.
tp_rank
if
self
.
fully_sharded
else
0
,
offset
=
shard_size_w2
*
self
.
tp_rank
if
self
.
fully_sharded
else
0
,
token_lora_mapping
=
token_lora_mapping
,
)
)
result
=
func
(
*
args
,
**
kwargs
)
result
=
func
(
*
args
,
**
kwargs
)
...
...
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
View file @
7320ca39
...
@@ -12,6 +12,64 @@ from vllm.utils.torch_utils import direct_register_custom_op
...
@@ -12,6 +12,64 @@ from vllm.utils.torch_utils import direct_register_custom_op
from
.utils
import
supports_pdl
from
.utils
import
supports_pdl
@triton.jit
def _get_lora_id(
    lora_ids,
    token_lora_mapping_ptr,
    lora_idx,
    pid_m,
    top_k_num,
    naive_block_assignment: tl.constexpr,
):
    """Return the LoRA id the current program should process.

    naive_block_assignment: the id is read from ``token_lora_mapping_ptr``
    at index ``pid_m // top_k_num``.
    Otherwise: the id is read from ``lora_ids`` at index ``lora_idx``.
    """
    # The branch condition is a constexpr, so only one arm is compiled.
    if naive_block_assignment:
        owning_token = pid_m // top_k_num
        selected = tl.load(token_lora_mapping_ptr + owning_token)
    else:
        selected = tl.load(lora_ids + lora_idx)
    return selected
@triton.jit
def _get_expert_id(
    expert_ids_ptr,
    lora_id,
    pid_m,
    stride_el,
    max_loras,
    naive_block_assignment: tl.constexpr,
):
    """Return the expert id for the block handled by this program.

    naive_block_assignment: ``expert_ids_ptr`` is indexed directly by
    ``pid_m``.
    Otherwise: the index is ``lora_id * stride_el + pid_m``; the load is
    masked to indices below ``max_loras * stride_el`` and yields -1 when the
    mask is false.
    """
    # The branch condition is a constexpr, so only one arm is compiled.
    if naive_block_assignment:
        expert = tl.load(expert_ids_ptr + pid_m)
    else:
        flat_index = lora_id * stride_el + pid_m
        in_bounds = flat_index < max_loras * stride_el
        expert = tl.load(expert_ids_ptr + flat_index, mask=in_bounds, other=-1)
    return expert
@triton.jit
def _get_token_offs(
    sorted_token_ids_ptr,
    lora_id,
    pid_m,
    offs,
    stride_tl,
    max_loras,
    num_valid_tokens,
    naive_block_assignment: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Return the token offsets for the block handled by this program.

    ``num_valid_tokens`` doubles as the "no token" sentinel: the caller
    builds its token mask as ``offs_token < num_valid_tokens``, so any lane
    holding that value is ignored.

    naive_block_assignment: lane 0 carries ``pid_m`` and every other lane
    carries the sentinel.
    Otherwise: offsets are loaded from ``sorted_token_ids_ptr`` at
    ``stride_tl * lora_id + pid_m * BLOCK_SIZE_M + offs``, masked to indices
    below ``max_loras * stride_tl``.
    """
    if naive_block_assignment:
        return tl.where(offs == 0, pid_m, num_valid_tokens)
    else:
        offs_token_id = pid_m * BLOCK_SIZE_M + offs
        token_ind = stride_tl * lora_id + offs_token_id
        # Out-of-bounds lanes must yield the sentinel, not 0: a 0 would pass
        # the caller's `offs_token < num_valid_tokens` mask and be treated as
        # a real reference to token 0.
        return tl.load(
            sorted_token_ids_ptr + token_ind,
            mask=token_ind < max_loras * stride_tl,
            other=num_valid_tokens,
        )
_LORA_PTR_DICT
:
dict
[
tuple
[
int
,
...],
torch
.
tensor
]
=
{}
_LORA_PTR_DICT
:
dict
[
tuple
[
int
,
...],
torch
.
tensor
]
=
{}
...
@@ -36,6 +94,25 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
...
@@ -36,6 +94,25 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
return
_LORA_PTR_DICT
.
get
(
key
)
return
_LORA_PTR_DICT
.
get
(
key
)
def
_adjust_kernel_inputs
(
max_loras
:
int
,
sorted_token_ids
:
torch
.
Tensor
|
None
,
expert_ids
:
torch
.
Tensor
,
):
"""
helper function to adjust kernel inputs when sorted_token_ids is None
"""
if
sorted_token_ids
is
None
:
stride_tl
=
0
stride_el
=
0
grid_lora_dim
=
1
else
:
stride_tl
=
sorted_token_ids
.
stride
(
0
)
stride_el
=
expert_ids
.
stride
(
0
)
grid_lora_dim
=
max_loras
+
1
return
grid_lora_dim
,
stride_tl
,
stride_el
@
triton
.
jit
(
@
triton
.
jit
(
do_not_specialize
=
[
do_not_specialize
=
[
"num_valid_tokens"
,
"num_valid_tokens"
,
...
@@ -54,12 +131,14 @@ def _fused_moe_lora_kernel(
...
@@ -54,12 +131,14 @@ def _fused_moe_lora_kernel(
sorted_token_ids_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
expert_ids_ptr
,
num_tokens_post_padded_ptr
,
num_tokens_post_padded_ptr
,
token_lora_mapping_ptr
,
# Matrix dimensions
# Matrix dimensions
N
,
N
,
K
,
K
,
EM
,
EM
,
num_valid_tokens
,
num_valid_tokens
,
num_experts
,
num_experts
,
top_k_num
,
lora_ids
,
lora_ids
,
adapter_enabled
,
adapter_enabled
,
max_loras
,
# <<< PR2: rename, used for masks when grid axis-2 != max_loras
max_loras
,
# <<< PR2: rename, used for masks when grid axis-2 != max_loras
...
@@ -82,7 +161,11 @@ def _fused_moe_lora_kernel(
...
@@ -82,7 +161,11 @@ def _fused_moe_lora_kernel(
# Meta-parameters
# Meta-parameters
num_slice_a
:
tl
.
constexpr
,
num_slice_a
:
tl
.
constexpr
,
num_slice_c
:
tl
.
constexpr
,
num_slice_c
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
# top_k_num or 1 depending on input token
# is expanded by top_k or not
token_mapping_factor
:
tl
.
constexpr
,
# whether use naive block assignment
naive_block_assignment
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
ADD_INPUTS
:
tl
.
constexpr
,
ADD_INPUTS
:
tl
.
constexpr
,
USE_B_L2_CACHE
:
tl
.
constexpr
,
# new, enable .ca load for B
USE_B_L2_CACHE
:
tl
.
constexpr
,
# new, enable .ca load for B
...
@@ -97,26 +180,10 @@ def _fused_moe_lora_kernel(
...
@@ -97,26 +180,10 @@ def _fused_moe_lora_kernel(
):
):
pid
=
tl
.
program_id
(
axis
=
0
)
pid
=
tl
.
program_id
(
axis
=
0
)
slice_id
=
tl
.
program_id
(
axis
=
1
)
slice_id
=
tl
.
program_id
(
axis
=
1
)
lora_idx
=
tl
.
program_id
(
axis
=
2
)
lora_id
=
tl
.
load
(
lora_ids
+
lora_idx
)
if
lora_id
==
-
1
:
# Early exit for the no-lora case.
return
moe_enabled
=
tl
.
load
(
adapter_enabled
+
lora_id
)
if
moe_enabled
==
0
:
# Early exit for the no moe lora case.
return
# The grid's axis-2 dimension is max_loras + 1 to accommodate the -1 sentinel.
# This guard ensures we don't access sorted_token_ids / expert_ids /
# num_tokens_post_padded beyond their allocated bounds if an invalid
# lora_id somehow appears. Although the caller should pass correct
# max_loras, defensive programming prevents accidental out-of-bounds.
if
lora_id
>=
max_loras
:
return
grid_k
=
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
*
SPLIT_K
)
grid_k
=
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
*
SPLIT_K
)
# calculate pid_m,pid_n
# calculate pid_m,pid_n
lora_idx
=
tl
.
program_id
(
axis
=
2
)
pid_sk
=
pid
%
SPLIT_K
pid_sk
=
pid
%
SPLIT_K
pid_m_n
=
pid
//
SPLIT_K
pid_m_n
=
pid
//
SPLIT_K
num_pid_m
=
tl
.
cdiv
(
EM
,
BLOCK_SIZE_M
)
num_pid_m
=
tl
.
cdiv
(
EM
,
BLOCK_SIZE_M
)
...
@@ -129,14 +196,55 @@ def _fused_moe_lora_kernel(
...
@@ -129,14 +196,55 @@ def _fused_moe_lora_kernel(
pid_m
=
first_pid_m
+
((
pid_m_n
%
num_pid_in_group
)
%
group_size_m
)
pid_m
=
first_pid_m
+
((
pid_m_n
%
num_pid_in_group
)
%
group_size_m
)
pid_n
=
(
pid_m_n
%
num_pid_in_group
)
//
group_size_m
pid_n
=
(
pid_m_n
%
num_pid_in_group
)
//
group_size_m
offs
=
tl
.
arange
(
0
,
BLOCK_SIZE_M
).
to
(
tl
.
int64
)
# Get lora_id
lora_id
=
_get_lora_id
(
lora_ids
,
token_lora_mapping_ptr
,
lora_idx
,
pid_m
,
top_k_num
,
naive_block_assignment
,
)
if
lora_id
==
-
1
:
return
moe_enabled
=
tl
.
load
(
adapter_enabled
+
lora_id
)
if
moe_enabled
==
0
:
return
if
lora_id
>=
max_loras
:
return
# Non-naive only: check num_tokens_post_padded
if
not
naive_block_assignment
:
num_tokens_post_padded
=
tl
.
load
(
num_tokens_post_padded_ptr
+
lora_id
)
num_tokens_post_padded
=
tl
.
load
(
num_tokens_post_padded_ptr
+
lora_id
)
if
pid_m
*
BLOCK_SIZE_M
>=
num_tokens_post_padded
:
if
pid_m
*
BLOCK_SIZE_M
>=
num_tokens_post_padded
:
return
return
# get the expert_id to process curr shard
ind
=
lora_id
*
stride_el
+
pid_m
# Get expert_id
expert_id
=
tl
.
load
(
expert_ids_ptr
+
ind
,
ind
<
max_loras
*
stride_el
,
-
1
)
expert_id
=
_get_expert_id
(
expert_ids_ptr
,
lora_id
,
pid_m
,
stride_el
,
max_loras
,
naive_block_assignment
,
)
if
expert_id
==
-
1
:
if
expert_id
==
-
1
:
return
return
# Get token offsets
offs_token
=
_get_token_offs
(
sorted_token_ids_ptr
,
lora_id
,
pid_m
,
offs
,
stride_tl
,
max_loras
,
num_valid_tokens
,
naive_block_assignment
,
BLOCK_SIZE_M
,
)
# get a_ptr,b_ptr,c_ptr
# get a_ptr,b_ptr,c_ptr
cur_a_ptr
=
a_ptr
+
(
slice_id
%
num_slice_a
)
*
slice_a_size
cur_a_ptr
=
a_ptr
+
(
slice_id
%
num_slice_a
)
*
slice_a_size
cur_b_ptr
=
tl
.
load
(
b_ptr
+
slice_id
).
to
(
tl
.
pointer_type
(
c_ptr
.
dtype
.
element_ty
))
cur_b_ptr
=
tl
.
load
(
b_ptr
+
slice_id
).
to
(
tl
.
pointer_type
(
c_ptr
.
dtype
.
element_ty
))
...
@@ -145,19 +253,12 @@ def _fused_moe_lora_kernel(
...
@@ -145,19 +253,12 @@ def _fused_moe_lora_kernel(
# remove modulo wrap-around
# remove modulo wrap-around
offs_bn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
).
to
(
tl
.
int32
)
offs_bn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
).
to
(
tl
.
int32
)
offs_k
=
pid_sk
*
BLOCK_SIZE_K
+
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
offs_k
=
pid_sk
*
BLOCK_SIZE_K
+
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
offs_token_id
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
).
to
(
tl
.
int32
)
token_ind
=
stride_tl
*
lora_id
+
offs_token_id
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
token_ind
,
mask
=
token_ind
<
max_loras
*
stride_tl
,
other
=
num_valid_tokens
,
)
token_mask
=
offs_token
<
num_valid_tokens
token_mask
=
offs_token
<
num_valid_tokens
# get a_ptrs,b_ptrs
# get a_ptrs,b_ptrs
a_ptrs
=
cur_a_ptr
+
(
a_ptrs
=
cur_a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
offs_token
[:,
None
]
//
token_mapping_factor
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
)
b_ptrs
=
(
b_ptrs
=
(
...
@@ -230,9 +331,10 @@ def _fused_moe_lora_shrink(
...
@@ -230,9 +331,10 @@ def _fused_moe_lora_shrink(
torch
.
Tensor
torch
.
Tensor
],
# [(max_loras, num_experts, max_lora_rank, K,),...]
],
# [(max_loras, num_experts, max_lora_rank, K,),...]
topk_weights
:
torch
.
Tensor
,
# (num_tokens, top_k_num)
topk_weights
:
torch
.
Tensor
,
# (num_tokens, top_k_num)
sorted_token_ids
:
torch
.
Tensor
,
# (max_loras, _)
sorted_token_ids
:
torch
.
Tensor
|
None
,
# (max_loras, _)
expert_ids
:
torch
.
Tensor
,
# (max_loras, _ ,)
expert_ids
:
torch
.
Tensor
,
# (max_loras, _ ,) or (num_tokens * top_k,)
num_tokens_post_padded
:
torch
.
Tensor
,
# (max_loras, )
num_tokens_post_padded
:
torch
.
Tensor
|
None
,
# (max_loras, )
token_lora_mapping
:
torch
.
Tensor
,
top_k_num
:
int
,
top_k_num
:
int
,
lora_ids
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
...
@@ -270,13 +372,15 @@ def _fused_moe_lora_shrink(
...
@@ -270,13 +372,15 @@ def _fused_moe_lora_shrink(
b_ptr
=
_get_ptr
(
lora_a_stacked
,
device
)
b_ptr
=
_get_ptr
(
lora_a_stacked
,
device
)
grid_lora_dim
,
stride_tl
,
stride_el
=
_adjust_kernel_inputs
(
w1_lora_a_stacked
.
shape
[
0
],
sorted_token_ids
,
expert_ids
)
grid
=
lambda
META
:
(
grid
=
lambda
META
:
(
split_k
split_k
*
triton
.
cdiv
(
EM
,
META
[
"BLOCK_SIZE_M"
])
*
triton
.
cdiv
(
EM
,
META
[
"BLOCK_SIZE_M"
])
*
triton
.
cdiv
(
N
,
META
[
"BLOCK_SIZE_N"
]),
*
triton
.
cdiv
(
N
,
META
[
"BLOCK_SIZE_N"
]),
len
(
lora_a_stacked
),
len
(
lora_a_stacked
),
## max_loras + 1 to handle the no-lora case (lora_id == -1)
grid_lora_dim
,
lora_a_stacked
[
0
].
shape
[
0
]
+
1
,
)
)
_fused_moe_lora_kernel
[
grid
](
_fused_moe_lora_kernel
[
grid
](
qcurr_hidden_states
,
qcurr_hidden_states
,
...
@@ -286,11 +390,13 @@ def _fused_moe_lora_shrink(
...
@@ -286,11 +390,13 @@ def _fused_moe_lora_shrink(
sorted_token_ids
,
sorted_token_ids
,
expert_ids
,
expert_ids
,
num_tokens_post_padded
,
num_tokens_post_padded
,
token_lora_mapping
,
N
,
N
,
K
,
K
,
EM
,
EM
,
num_tokens
,
num_tokens
,
num_experts
,
num_experts
,
top_k_num
,
lora_ids
,
lora_ids
,
adapter_enabled
,
adapter_enabled
,
lora_a_stacked
[
0
].
shape
[
0
],
lora_a_stacked
[
0
].
shape
[
0
],
...
@@ -302,13 +408,14 @@ def _fused_moe_lora_shrink(
...
@@ -302,13 +408,14 @@ def _fused_moe_lora_shrink(
w1_lora_a_stacked
.
stride
(
2
),
w1_lora_a_stacked
.
stride
(
2
),
a_intermediate_cache1
.
stride
(
2
),
a_intermediate_cache1
.
stride
(
2
),
a_intermediate_cache1
.
stride
(
3
),
a_intermediate_cache1
.
stride
(
3
),
s
orted_token_ids
.
stride
(
0
)
,
s
tride_tl
,
expert_ids
.
stride
(
0
)
,
stride
_el
,
slice_a_size
=
qcurr_hidden_states
.
numel
(),
slice_a_size
=
qcurr_hidden_states
.
numel
(),
slice_c_size
=
a_intermediate_cache1
.
numel
()
//
num_slices
,
slice_c_size
=
a_intermediate_cache1
.
numel
()
//
num_slices
,
num_slice_a
=
1
,
num_slice_a
=
1
,
num_slice_c
=
num_slices
,
num_slice_c
=
num_slices
,
top_k
=
1
if
mul_routed_weight
else
top_k_num
,
token_mapping_factor
=
1
if
mul_routed_weight
else
top_k_num
,
naive_block_assignment
=
sorted_token_ids
is
None
,
MUL_ROUTED_WEIGHT
=
False
,
MUL_ROUTED_WEIGHT
=
False
,
ADD_INPUTS
=
False
,
ADD_INPUTS
=
False
,
USE_B_L2_CACHE
=
True
,
# new
USE_B_L2_CACHE
=
True
,
# new
...
@@ -325,9 +432,10 @@ def _fused_moe_lora_expand(
...
@@ -325,9 +432,10 @@ def _fused_moe_lora_expand(
torch
.
Tensor
torch
.
Tensor
],
# [(max_loras, num_experts, max_lora_rank, K,),...]
],
# [(max_loras, num_experts, max_lora_rank, K,),...]
topk_weights
:
torch
.
Tensor
,
# (num_tokens, top_k_num)
topk_weights
:
torch
.
Tensor
,
# (num_tokens, top_k_num)
sorted_token_ids
:
torch
.
Tensor
,
# (max_loras, _)
sorted_token_ids
:
torch
.
Tensor
|
None
,
# (max_loras, _)
expert_ids
:
torch
.
Tensor
,
# (max_loras, _ ,)
expert_ids
:
torch
.
Tensor
,
# (max_loras, _ ,) or (num_tokens * top_k,)
num_tokens_post_padded
:
torch
.
Tensor
,
# (max_loras, )
num_tokens_post_padded
:
torch
.
Tensor
|
None
,
# (max_loras, )
token_lora_mapping
:
torch
.
Tensor
,
top_k_num
:
int
,
top_k_num
:
int
,
lora_ids
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
...
@@ -375,11 +483,14 @@ def _fused_moe_lora_expand(
...
@@ -375,11 +483,14 @@ def _fused_moe_lora_expand(
"launch_pdl"
:
use_gdc
,
# triton kernel metadata
"launch_pdl"
:
use_gdc
,
# triton kernel metadata
}
}
grid_lora_dim
,
stride_tl
,
stride_el
=
_adjust_kernel_inputs
(
w1_lora_b_stacked
.
shape
[
0
],
sorted_token_ids
,
expert_ids
)
grid
=
lambda
META
:
(
grid
=
lambda
META
:
(
triton
.
cdiv
(
EM
,
META
[
"BLOCK_SIZE_M"
])
*
triton
.
cdiv
(
N
,
META
[
"BLOCK_SIZE_N"
]),
triton
.
cdiv
(
EM
,
META
[
"BLOCK_SIZE_M"
])
*
triton
.
cdiv
(
N
,
META
[
"BLOCK_SIZE_N"
]),
len
(
lora_b_stacked
),
len
(
lora_b_stacked
),
## max_loras + 1 to handle the no-lora case (lora_id == -1)
grid_lora_dim
,
lora_b_stacked
[
0
].
shape
[
0
]
+
1
,
)
)
# Fast path: directly accumulate into the corresponding slice interval of output.
# Fast path: directly accumulate into the corresponding slice interval of output.
...
@@ -394,11 +505,13 @@ def _fused_moe_lora_expand(
...
@@ -394,11 +505,13 @@ def _fused_moe_lora_expand(
sorted_token_ids
,
sorted_token_ids
,
expert_ids
,
expert_ids
,
num_tokens_post_padded
,
num_tokens_post_padded
,
token_lora_mapping
,
N
,
N
,
K
,
K
,
EM
,
EM
,
num_tokens
,
num_tokens
,
num_experts
,
num_experts
,
top_k_num
,
lora_ids
,
lora_ids
,
adapter_enabled
,
adapter_enabled
,
lora_b_stacked
[
0
].
shape
[
0
],
lora_b_stacked
[
0
].
shape
[
0
],
...
@@ -410,13 +523,14 @@ def _fused_moe_lora_expand(
...
@@ -410,13 +523,14 @@ def _fused_moe_lora_expand(
w1_lora_b_stacked
.
stride
(
2
),
w1_lora_b_stacked
.
stride
(
2
),
out_view
.
stride
(
1
),
out_view
.
stride
(
1
),
out_view
.
stride
(
2
),
out_view
.
stride
(
2
),
s
orted_token_ids
.
stride
(
0
)
,
s
tride_tl
,
expert_ids
.
stride
(
0
)
,
stride
_el
,
slice_a_size
=
a_intermediate_cache1
.
numel
()
//
num_slices
,
slice_a_size
=
a_intermediate_cache1
.
numel
()
//
num_slices
,
slice_c_size
=
slice_c_size
,
slice_c_size
=
slice_c_size
,
num_slice_a
=
num_slices
,
num_slice_a
=
num_slices
,
num_slice_c
=
num_slices
,
num_slice_c
=
num_slices
,
top_k
=
1
,
token_mapping_factor
=
1
,
naive_block_assignment
=
sorted_token_ids
is
None
,
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
ADD_INPUTS
=
True
,
ADD_INPUTS
=
True
,
USE_B_L2_CACHE
=
True
,
# new
USE_B_L2_CACHE
=
True
,
# new
...
@@ -436,9 +550,10 @@ def _fused_moe_lora(
...
@@ -436,9 +550,10 @@ def _fused_moe_lora(
torch
.
Tensor
torch
.
Tensor
],
# [(max_loras, num_experts, N, max_lora_rank,),...]
],
# [(max_loras, num_experts, N, max_lora_rank,),...]
topk_weights
:
torch
.
Tensor
,
# (num_tokens, top_k_num)
topk_weights
:
torch
.
Tensor
,
# (num_tokens, top_k_num)
sorted_token_ids
:
torch
.
Tensor
,
# (max_loras, _)
sorted_token_ids
:
torch
.
Tensor
|
None
,
# (max_loras, _)
expert_ids
:
torch
.
Tensor
,
# (max_loras, _ ,)
expert_ids
:
torch
.
Tensor
,
# (max_loras, _ ,) or (num_tokens * top_k,)
num_tokens_post_padded
:
torch
.
Tensor
,
# (max_loras, )
num_tokens_post_padded
:
torch
.
Tensor
|
None
,
# (max_loras, )
token_lora_mapping
:
torch
.
Tensor
,
max_lora_rank
:
int
,
max_lora_rank
:
int
,
top_k_num
:
int
,
top_k_num
:
int
,
lora_ids
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
...
@@ -462,6 +577,12 @@ def _fused_moe_lora(
...
@@ -462,6 +577,12 @@ def _fused_moe_lora(
offset
:
int
=
0
,
offset
:
int
=
0
,
)
->
None
:
)
->
None
:
assert
len
(
lora_a_stacked
)
==
len
(
lora_b_stacked
)
>
0
assert
len
(
lora_a_stacked
)
==
len
(
lora_b_stacked
)
>
0
assert
topk_weights
.
dim
()
==
qcurr_hidden_states
.
dim
()
==
2
if
sorted_token_ids
is
None
:
assert
expert_ids
.
dim
()
==
1
else
:
assert
sorted_token_ids
is
not
None
assert
num_tokens_post_padded
is
not
None
assert
(
assert
(
sorted_token_ids
.
dim
()
sorted_token_ids
.
dim
()
==
expert_ids
.
dim
()
==
expert_ids
.
dim
()
...
@@ -482,10 +603,15 @@ def _fused_moe_lora(
...
@@ -482,10 +603,15 @@ def _fused_moe_lora(
num_experts
=
lora_a_stacked
[
0
].
shape
[
1
]
num_experts
=
lora_a_stacked
[
0
].
shape
[
1
]
N
=
max_lora_rank
N
=
max_lora_rank
M
=
topk_weights
.
shape
[
0
]
M
=
topk_weights
.
shape
[
0
]
EM
=
sorted_token_ids
.
shape
[
1
]
K
=
qcurr_hidden_states
.
shape
[
1
]
K
=
qcurr_hidden_states
.
shape
[
1
]
num_tokens
=
M
*
top_k_num
num_tokens
=
M
*
top_k_num
w1_output_dim_size
=
w1_lora_b_stacked
.
shape
[
2
]
w1_output_dim_size
=
w1_lora_b_stacked
.
shape
[
2
]
assert
shrink_block_size_m
==
expand_block_size_m
EM
=
(
sorted_token_ids
.
shape
[
1
]
if
sorted_token_ids
is
not
None
else
num_tokens
*
shrink_block_size_m
)
a_intermediate_cache1
=
torch
.
zeros
(
a_intermediate_cache1
=
torch
.
zeros
(
(
num_slices
,
M
,
top_k_num
,
max_lora_rank
),
(
num_slices
,
M
,
top_k_num
,
max_lora_rank
),
...
@@ -502,6 +628,7 @@ def _fused_moe_lora(
...
@@ -502,6 +628,7 @@ def _fused_moe_lora(
sorted_token_ids
,
sorted_token_ids
,
expert_ids
,
expert_ids
,
num_tokens_post_padded
,
num_tokens_post_padded
,
token_lora_mapping
,
top_k_num
,
top_k_num
,
lora_ids
,
lora_ids
,
adapter_enabled
,
adapter_enabled
,
...
@@ -546,6 +673,7 @@ def _fused_moe_lora(
...
@@ -546,6 +673,7 @@ def _fused_moe_lora(
sorted_token_ids
,
sorted_token_ids
,
expert_ids
,
expert_ids
,
num_tokens_post_padded
,
num_tokens_post_padded
,
token_lora_mapping
,
top_k_num
,
top_k_num
,
lora_ids
,
lora_ids
,
adapter_enabled
,
adapter_enabled
,
...
@@ -579,9 +707,10 @@ def _fused_moe_lora_fake(
...
@@ -579,9 +707,10 @@ def _fused_moe_lora_fake(
lora_a_stacked
:
list
[
torch
.
Tensor
],
lora_a_stacked
:
list
[
torch
.
Tensor
],
lora_b_stacked
:
list
[
torch
.
Tensor
],
lora_b_stacked
:
list
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
|
None
,
expert_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
|
None
,
token_lora_mapping
:
torch
.
Tensor
,
max_lora_rank
:
int
,
max_lora_rank
:
int
,
top_k_num
:
int
,
top_k_num
:
int
,
lora_ids
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
...
@@ -610,9 +739,10 @@ def _fused_moe_lora_shrink_fake(
...
@@ -610,9 +739,10 @@ def _fused_moe_lora_shrink_fake(
qcurr_hidden_states
:
torch
.
Tensor
,
qcurr_hidden_states
:
torch
.
Tensor
,
lora_a_stacked
:
list
[
torch
.
Tensor
],
lora_a_stacked
:
list
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
|
None
,
expert_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
|
None
,
token_lora_mapping
:
torch
.
Tensor
,
top_k_num
:
int
,
top_k_num
:
int
,
lora_ids
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
...
@@ -642,9 +772,10 @@ def _fused_moe_lora_expand_fake(
...
@@ -642,9 +772,10 @@ def _fused_moe_lora_expand_fake(
a_intermediate_cache1
:
torch
.
Tensor
,
a_intermediate_cache1
:
torch
.
Tensor
,
lora_b_stacked
:
list
[
torch
.
Tensor
],
lora_b_stacked
:
list
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
|
None
,
expert_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
|
None
,
token_lora_mapping
:
torch
.
Tensor
,
top_k_num
:
int
,
top_k_num
:
int
,
lora_ids
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
...
...
vllm/lora/punica_wrapper/punica_base.py
View file @
7320ca39
...
@@ -458,7 +458,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
...
@@ -458,7 +458,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
adapter_enabled
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
pad_sorted_ids
:
bool
=
False
,
pad_sorted_ids
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
mixture-of-experts (MoE) execution.
...
@@ -473,9 +473,9 @@ class PunicaWrapperBase(PunicaWrapperABC):
...
@@ -473,9 +473,9 @@ class PunicaWrapperBase(PunicaWrapperABC):
lora_a_stacked
:
tuple
[
torch
.
Tensor
,
...],
lora_a_stacked
:
tuple
[
torch
.
Tensor
,
...],
lora_b_stacked
:
tuple
[
torch
.
Tensor
,
...],
lora_b_stacked
:
tuple
[
torch
.
Tensor
,
...],
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
|
None
,
expert_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
|
None
,
max_lora_rank
:
int
,
max_lora_rank
:
int
,
top_k_num
:
int
,
top_k_num
:
int
,
shrink_config
,
shrink_config
,
...
@@ -484,6 +484,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
...
@@ -484,6 +484,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
mul_routed_weight
=
False
,
mul_routed_weight
=
False
,
fully_sharded
:
bool
=
False
,
fully_sharded
:
bool
=
False
,
offset
:
int
=
0
,
offset
:
int
=
0
,
token_lora_mapping
:
torch
.
Tensor
|
None
=
None
,
):
):
"""
"""
Performs a fused forward computation for LoRA of
Performs a fused forward computation for LoRA of
...
...
vllm/lora/punica_wrapper/punica_gpu.py
View file @
7320ca39
...
@@ -310,11 +310,20 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -310,11 +310,20 @@ class PunicaWrapperGPU(PunicaWrapperBase):
adapter_enabled
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
pad_sorted_ids
:
bool
=
False
,
pad_sorted_ids
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
naive_block_assignment
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
mixture-of-experts (MoE) execution.
"""
"""
(
token_lora_mapping
,
_
,
_
,
_
,
lora_ids
,
_
)
=
self
.
token_mapping_meta
.
meta_args
(
num_tokens
)
if
naive_block_assignment
:
expert_ids
=
topk_ids
.
reshape
(
-
1
)
sorted_ids
=
None
num_tokens_post_pad
=
None
else
:
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
if
pad_sorted_ids
:
if
pad_sorted_ids
:
max_num_tokens_padded
=
round_up
(
max_num_tokens_padded
,
block_size
)
max_num_tokens_padded
=
round_up
(
max_num_tokens_padded
,
block_size
)
...
@@ -334,10 +343,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -334,10 +343,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
(
max_loras
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
(
max_loras
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
)
(
token_lora_mapping
,
_
,
_
,
_
,
lora_ids
,
_
)
=
self
.
token_mapping_meta
.
meta_args
(
num_tokens
)
ops
.
moe_lora_align_block_size
(
ops
.
moe_lora_align_block_size
(
topk_ids
,
topk_ids
,
token_lora_mapping
,
token_lora_mapping
,
...
@@ -355,7 +360,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -355,7 +360,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
if
expert_map
is
not
None
:
if
expert_map
is
not
None
:
expert_ids
=
expert_map
[
expert_ids
]
expert_ids
=
expert_map
[
expert_ids
]
return
sorted_ids
,
expert_ids
,
num_tokens_post_pad
return
None
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
def
add_lora_fused_moe
(
def
add_lora_fused_moe
(
self
,
self
,
...
@@ -364,9 +369,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -364,9 +369,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
lora_a_stacked
:
tuple
[
torch
.
Tensor
,
...],
lora_a_stacked
:
tuple
[
torch
.
Tensor
,
...],
lora_b_stacked
:
tuple
[
torch
.
Tensor
,
...],
lora_b_stacked
:
tuple
[
torch
.
Tensor
,
...],
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
|
None
,
expert_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
|
None
,
max_lora_rank
:
int
,
max_lora_rank
:
int
,
top_k_num
:
int
,
top_k_num
:
int
,
shrink_config
,
shrink_config
,
...
@@ -375,11 +380,21 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -375,11 +380,21 @@ class PunicaWrapperGPU(PunicaWrapperBase):
mul_routed_weight
=
False
,
mul_routed_weight
=
False
,
fully_sharded
:
bool
=
False
,
fully_sharded
:
bool
=
False
,
offset
:
int
=
0
,
offset
:
int
=
0
,
token_lora_mapping
:
torch
.
Tensor
|
None
=
None
,
):
):
"""
"""
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
"""
"""
(
_
,
_
,
_
,
_
,
lora_ids
,
_
)
=
self
.
token_mapping_meta
.
meta_args
(
x
.
size
(
0
))
(
token_lora_mapping_meta
,
_
,
_
,
_
,
lora_ids
,
_
,
)
=
self
.
token_mapping_meta
.
meta_args
(
x
.
size
(
0
))
if
token_lora_mapping
is
None
:
token_lora_mapping
=
token_lora_mapping_meta
fused_moe_lora
(
fused_moe_lora
(
y
,
y
,
x
,
x
,
...
@@ -389,6 +404,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -389,6 +404,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
sorted_token_ids
,
sorted_token_ids
,
expert_ids
,
expert_ids
,
num_tokens_post_padded
,
num_tokens_post_padded
,
token_lora_mapping
,
max_lora_rank
,
max_lora_rank
,
top_k_num
,
top_k_num
,
lora_ids
,
lora_ids
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment