Unverified Commit 3ecd0bf9 authored by gnovack's avatar gnovack Committed by GitHub
Browse files

Add TMA support to fused_moe_lora kernel (#32195)


Signed-off-by: default avatargnovack <gnovack@amazon.com>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent e3eb146f
...@@ -231,17 +231,22 @@ def use_torch( ...@@ -231,17 +231,22 @@ def use_torch(
lora_a_stacked, lora_a_stacked,
lora_b_stacked, lora_b_stacked,
top_k_num, top_k_num,
num_slices=1,
): ):
outputs = [] outputs = []
for i in range(hidden_states.shape[0]): for i in range(hidden_states.shape[0]):
slice_tensors = []
for slice_id in range(num_slices):
lora_idx = token_lora_mapping[i] lora_idx = token_lora_mapping[i]
expert_ids = topk_ids[i] expert_ids = topk_ids[i]
lora_a = lora_a_stacked[0][lora_idx][expert_ids] lora_a = lora_a_stacked[slice_id][lora_idx][expert_ids]
lora_b = lora_b_stacked[0][lora_idx][expert_ids] lora_b = lora_b_stacked[slice_id][lora_idx][expert_ids]
tensors = [ tensors = [
hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num) hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num)
] ]
outputs.append(torch.stack(tensors, dim=0)) slice_tensors.append(torch.stack(tensors, dim=0))
outputs.append(torch.concat(slice_tensors, dim=-1))
return torch.stack(outputs, dim=0) return torch.stack(outputs, dim=0)
...@@ -259,6 +264,7 @@ SEED = [42] ...@@ -259,6 +264,7 @@ SEED = [42]
@pytest.mark.parametrize("K", [2048]) @pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64]) @pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
@pytest.mark.parametrize("block_size", [16]) @pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_slices", [1, 2])
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("seed", SEED)
...@@ -271,6 +277,7 @@ def test_fused_moe_lora_kernel( ...@@ -271,6 +277,7 @@ def test_fused_moe_lora_kernel(
K, K,
max_lora_rank, max_lora_rank,
block_size, block_size,
num_slices,
dtype, dtype,
device, device,
seed, seed,
...@@ -295,17 +302,19 @@ def test_fused_moe_lora_kernel( ...@@ -295,17 +302,19 @@ def test_fused_moe_lora_kernel(
), ),
dtype=dtype, dtype=dtype,
) )
for _ in range(num_slices)
] ]
lora_b_stacked = [ lora_b_stacked = [
torch.rand( torch.rand(
( (
max_loras, max_loras,
num_experts, num_experts,
N, N // num_slices,
max_lora_rank, max_lora_rank,
), ),
dtype=dtype, dtype=dtype,
) )
for _ in range(num_slices)
] ]
hidden_states = torch.rand( hidden_states = torch.rand(
( (
...@@ -340,6 +349,7 @@ def test_fused_moe_lora_kernel( ...@@ -340,6 +349,7 @@ def test_fused_moe_lora_kernel(
lora_a_stacked, lora_a_stacked,
lora_b_stacked, lora_b_stacked,
top_k_num, top_k_num,
num_slices,
) )
torch.testing.assert_close(output, output2, atol=1e-2, rtol=1e-2) torch.testing.assert_close(output, output2, atol=1e-2, rtol=1e-2)
...@@ -434,6 +444,7 @@ def use_fused_moe_lora_kernel_naive( ...@@ -434,6 +444,7 @@ def use_fused_moe_lora_kernel_naive(
@pytest.mark.parametrize("K", [2048]) @pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32]) @pytest.mark.parametrize("max_lora_rank", [16, 32])
@pytest.mark.parametrize("block_size", [16]) @pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_slices", [1, 2])
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("seed", SEED)
...@@ -446,6 +457,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment( ...@@ -446,6 +457,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
K, K,
max_lora_rank, max_lora_rank,
block_size, block_size,
num_slices,
dtype, dtype,
device, device,
seed, seed,
...@@ -484,17 +496,19 @@ def test_fused_moe_lora_kernel_naive_block_assignment( ...@@ -484,17 +496,19 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
), ),
dtype=dtype, dtype=dtype,
) )
for _ in range(num_slices)
] ]
lora_b_stacked = [ lora_b_stacked = [
torch.rand( torch.rand(
( (
max_loras, max_loras,
num_experts, num_experts,
N, N // num_slices,
max_lora_rank, max_lora_rank,
), ),
dtype=dtype, dtype=dtype,
) )
for _ in range(num_slices)
] ]
hidden_states = torch.rand( hidden_states = torch.rand(
( (
...@@ -529,6 +543,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment( ...@@ -529,6 +543,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
lora_a_stacked, lora_a_stacked,
lora_b_stacked, lora_b_stacked,
top_k_num, top_k_num,
num_slices,
) )
torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2) torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2)
......
...@@ -2,7 +2,11 @@ ...@@ -2,7 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import shutil
import pytest import pytest
import torch
from safetensors.torch import load_file, save_file
import vllm import vllm
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -122,6 +126,41 @@ def test_olmoe_lora_mixed(olmoe_lora_files): ...@@ -122,6 +126,41 @@ def test_olmoe_lora_mixed(olmoe_lora_files):
generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None]) generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])
def test_olmoe_lora_mixed_random(olmoe_lora_files, tmp_path):
# Create a dummy LoRA with random weights based on the real one
random_lora_path = tmp_path / "random_lora"
shutil.copytree(olmoe_lora_files, random_lora_path)
weights_path = random_lora_path / "adapter_model.safetensors"
weights = load_file(str(weights_path))
random_weights = {k: torch.randn_like(v) for k, v in weights.items()}
save_file(random_weights, str(weights_path))
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
)
prompts = [
PROMPT_TEMPLATE.format(context="How many candidates are there?"),
PROMPT_TEMPLATE.format(context="Count the number of candidates."),
]
lora_requests = [
LoRARequest("real", 1, olmoe_lora_files),
LoRARequest("random", 2, str(random_lora_path)),
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)
assert outputs[0].outputs[0].text.strip().startswith(EXPECTED_LORA_OUTPUT[0])
@pytest.mark.parametrize("fully_sharded_loras", [False, True]) @pytest.mark.parametrize("fully_sharded_loras", [False, True])
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
def test_olmoe_lora_tp2(olmoe_lora_files, fully_sharded_loras): def test_olmoe_lora_tp2(olmoe_lora_files, fully_sharded_loras):
......
...@@ -8,9 +8,10 @@ from vllm.distributed import ( ...@@ -8,9 +8,10 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.triton_utils.allocation import set_triton_allocator
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from .utils import supports_pdl from .utils import supports_pdl, supports_tma
@triton.jit @triton.jit
...@@ -70,6 +71,37 @@ def _get_token_offs( ...@@ -70,6 +71,37 @@ def _get_token_offs(
) )
@triton.jit
def _get_c_ptrs(
cur_c_ptr,
lora_id,
pid_m,
offs,
offs_token,
offs_cn,
stride_cm,
stride_cn,
EM: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
sort_c: tl.constexpr,
):
# When sort_c is true, store the output in c_ptr using token order defined
# in sorted_token_ids_ptr; otherwise, use the original token order from the prompt
if sort_c:
offs_token_id = pid_m * BLOCK_SIZE_M + offs
c_ptrs = (
cur_c_ptr
+ lora_id * EM * stride_cm
+ stride_cm * offs_token_id[:, None]
+ stride_cn * offs_cn[None, :]
)
else:
c_ptrs = (
cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
)
return c_ptrs
_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {} _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
...@@ -125,7 +157,9 @@ def _adjust_kernel_inputs( ...@@ -125,7 +157,9 @@ def _adjust_kernel_inputs(
) )
def _fused_moe_lora_kernel( def _fused_moe_lora_kernel(
a_ptr, a_ptr,
a_desc,
b_ptr, b_ptr,
b_desc,
c_ptr, c_ptr,
topk_weights_ptr, topk_weights_ptr,
sorted_token_ids_ptr, sorted_token_ids_ptr,
...@@ -177,6 +211,18 @@ def _fused_moe_lora_kernel( ...@@ -177,6 +211,18 @@ def _fused_moe_lora_kernel(
USE_GDC: tl.constexpr, USE_GDC: tl.constexpr,
launch_pdl: tl.constexpr, launch_pdl: tl.constexpr,
IS_PRIMARY: tl.constexpr, IS_PRIMARY: tl.constexpr,
USE_TMA: tl.constexpr,
# sort_c determines whether tokens are stored in C in the order determined
# by sorted_token_ids to enable later TMA loads from this tensor.
#
# When USE_TMA is enabled, the parameter combinations are:
# a_desc | b_desc | sort_c | Use Case
# --------|---------|--------|-----------------------------
# yes | yes | False | expand kernel (num_slices=1)
# no | yes | True | shrink kernel (num_slices=1)
# yes | no | False | expand kernel (num_slices>1)
# no | no | True | shrink kernel (num_slices>1)
sort_c: tl.constexpr,
): ):
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
slice_id = tl.program_id(axis=1) slice_id = tl.program_id(axis=1)
...@@ -250,17 +296,43 @@ def _fused_moe_lora_kernel( ...@@ -250,17 +296,43 @@ def _fused_moe_lora_kernel(
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty)) cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
# remove modulo wrap-around
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
token_mask = offs_token < num_valid_tokens token_mask = offs_token < num_valid_tokens
# get a_ptrs,b_ptrs if USE_TMA and a_desc is not None:
# Expand path - with TMA enabled, load from A using TMA descriptor
offs_am = (
slice_id * max_loras * EM
+ lora_id * EM
+ pid_m * BLOCK_SIZE_M // token_mapping_factor
)
offs_ak = pid_sk * BLOCK_SIZE_K
else:
# Shrink path - load hidden states based on order defined in
# 'sorted_token_ids_ptr' then store them in c_ptr in this same sorted order
tl.static_assert(a_desc is None, "a_desc must be none")
a_ptrs = cur_a_ptr + ( a_ptrs = cur_a_ptr + (
offs_token[:, None] // token_mapping_factor * stride_am offs_token[:, None] // token_mapping_factor * stride_am
+ offs_k[None, :] * stride_ak + offs_k[None, :] * stride_ak
) )
if USE_TMA:
offs_bn = pid_n * BLOCK_SIZE_N
offs_bk = pid_sk * BLOCK_SIZE_K
if b_desc is None:
# Note(@gnovack) - Allocation of TMA descriptors on-device
# can cause conflicts when running in parallel via PDL
if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait()
b_desc = tl.make_tensor_descriptor(
cur_b_ptr,
shape=[max_loras, num_experts, N, K],
strides=[stride_bl, stride_be, stride_bn, stride_bk],
block_shape=[1, 1, BLOCK_SIZE_N, BLOCK_SIZE_K],
)
else:
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
b_ptrs = ( b_ptrs = (
cur_b_ptr cur_b_ptr
+ lora_id * stride_bl + lora_id * stride_bl
...@@ -273,35 +345,41 @@ def _fused_moe_lora_kernel( ...@@ -273,35 +345,41 @@ def _fused_moe_lora_kernel(
# GDC launch dependents hints the runtime system to launch dependent kernels. # GDC launch dependents hints the runtime system to launch dependent kernels.
tl.extra.cuda.gdc_launch_dependents() tl.extra.cuda.gdc_launch_dependents()
# accumulator
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
if USE_GDC and not IS_PRIMARY: if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait() tl.extra.cuda.gdc_wait()
for k in range(0, grid_k): for k in range(0, grid_k):
k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K) cur_k_offset = k * (BLOCK_SIZE_K * SPLIT_K)
# GDC wait waits for ALL programs in the prior kernel to complete k_remaining = K - cur_k_offset
# before continuing.
# pre-fetch lora weight # pre-fetch lora weight
if b_desc is not None:
b = (
b_desc.load([lora_id, expert_id, offs_bn, offs_bk + cur_k_offset])
.reshape(BLOCK_SIZE_N, BLOCK_SIZE_K)
.T
)
else:
# add (offs_bn < N) mask; optional .ca for B # add (offs_bn < N) mask; optional .ca for B
b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N) b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
if USE_B_L2_CACHE: if USE_B_L2_CACHE:
b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca") b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca")
else: else:
b = tl.load(b_ptrs, mask=b_mask, other=0.0) b = tl.load(b_ptrs, mask=b_mask, other=0.0)
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
if USE_GDC and not IS_PRIMARY: if a_desc is not None:
tl.extra.cuda.gdc_wait() a = a_desc.load([offs_am, offs_ak + cur_k_offset])
else:
a = tl.load( a = tl.load(
a_ptrs, a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < k_remaining), mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
other=0.0, other=0.0,
) )
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
accumulator += tl.dot(a, b)
if MUL_ROUTED_WEIGHT: if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0) moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)
...@@ -309,7 +387,19 @@ def _fused_moe_lora_kernel( ...@@ -309,7 +387,19 @@ def _fused_moe_lora_kernel(
accumulator = accumulator.to(c_ptr.dtype.element_ty) accumulator = accumulator.to(c_ptr.dtype.element_ty)
# Write back the block of the output # Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_ptrs = _get_c_ptrs(
cur_c_ptr,
lora_id,
pid_m,
offs,
offs_token,
offs_cn,
stride_cm,
stride_cn,
EM,
BLOCK_SIZE_M,
sort_c,
)
c_mask = token_mask[:, None] & (offs_cn[None, :] < N) c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
if SPLIT_K == 1: if SPLIT_K == 1:
...@@ -357,6 +447,7 @@ def _fused_moe_lora_shrink( ...@@ -357,6 +447,7 @@ def _fused_moe_lora_shrink(
num_active_loras: int, num_active_loras: int,
mul_routed_weight: bool = False, mul_routed_weight: bool = False,
use_gdc: bool = False, use_gdc: bool = False,
use_tma: bool = False,
) -> None: ) -> None:
w1_lora_a_stacked = lora_a_stacked[0] w1_lora_a_stacked = lora_a_stacked[0]
shrink_config = { shrink_config = {
...@@ -369,6 +460,7 @@ def _fused_moe_lora_shrink( ...@@ -369,6 +460,7 @@ def _fused_moe_lora_shrink(
"SPLIT_K": split_k, "SPLIT_K": split_k,
"USE_GDC": use_gdc, "USE_GDC": use_gdc,
"launch_pdl": use_gdc, # triton kernel metadata "launch_pdl": use_gdc, # triton kernel metadata
"USE_TMA": use_tma,
} }
b_ptr = _get_ptr(lora_a_stacked, device) b_ptr = _get_ptr(lora_a_stacked, device)
...@@ -383,9 +475,20 @@ def _fused_moe_lora_shrink( ...@@ -383,9 +475,20 @@ def _fused_moe_lora_shrink(
len(lora_a_stacked), len(lora_a_stacked),
grid_lora_dim, grid_lora_dim,
) )
a_desc = None
b_desc = None
if use_tma and num_slices == 1:
b_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
lora_a_stacked[0],
[1, 1, shrink_config["BLOCK_SIZE_N"], shrink_config["BLOCK_SIZE_K"]],
)
_fused_moe_lora_kernel[grid]( _fused_moe_lora_kernel[grid](
qcurr_hidden_states, qcurr_hidden_states,
a_desc,
b_ptr, b_ptr,
b_desc,
a_intermediate_cache1, a_intermediate_cache1,
topk_weights, topk_weights,
sorted_token_ids, sorted_token_ids,
...@@ -407,8 +510,8 @@ def _fused_moe_lora_shrink( ...@@ -407,8 +510,8 @@ def _fused_moe_lora_shrink(
w1_lora_a_stacked.stride(1), w1_lora_a_stacked.stride(1),
w1_lora_a_stacked.stride(3), w1_lora_a_stacked.stride(3),
w1_lora_a_stacked.stride(2), w1_lora_a_stacked.stride(2),
a_intermediate_cache1.stride(2), a_intermediate_cache1.stride(-2),
a_intermediate_cache1.stride(3), a_intermediate_cache1.stride(-1),
stride_tl, stride_tl,
stride_el, stride_el,
slice_a_size=qcurr_hidden_states.numel(), slice_a_size=qcurr_hidden_states.numel(),
...@@ -419,7 +522,8 @@ def _fused_moe_lora_shrink( ...@@ -419,7 +522,8 @@ def _fused_moe_lora_shrink(
naive_block_assignment=sorted_token_ids is None, naive_block_assignment=sorted_token_ids is None,
MUL_ROUTED_WEIGHT=False, MUL_ROUTED_WEIGHT=False,
ADD_INPUTS=False, ADD_INPUTS=False,
USE_B_L2_CACHE=True, # new USE_B_L2_CACHE=True,
sort_c=use_tma and sorted_token_ids is not None,
IS_PRIMARY=True, IS_PRIMARY=True,
**shrink_config, **shrink_config,
) )
...@@ -462,6 +566,7 @@ def _fused_moe_lora_expand( ...@@ -462,6 +566,7 @@ def _fused_moe_lora_expand(
mul_routed_weight: bool = False, mul_routed_weight: bool = False,
offset: int = 0, offset: int = 0,
use_gdc: bool = False, use_gdc: bool = False,
use_tma: bool = False,
) -> None: ) -> None:
b_ptr = _get_ptr(lora_b_stacked, device) b_ptr = _get_ptr(lora_b_stacked, device)
K = max_lora_rank K = max_lora_rank
...@@ -470,7 +575,7 @@ def _fused_moe_lora_expand( ...@@ -470,7 +575,7 @@ def _fused_moe_lora_expand(
w1_lora_b_stacked = lora_b_stacked[0] w1_lora_b_stacked = lora_b_stacked[0]
a_intermediate_cache1 = a_intermediate_cache1.view( a_intermediate_cache1 = a_intermediate_cache1.view(
-1, a_intermediate_cache1.shape[3] -1, a_intermediate_cache1.shape[-1]
) )
expand_config = { expand_config = {
...@@ -483,6 +588,7 @@ def _fused_moe_lora_expand( ...@@ -483,6 +588,7 @@ def _fused_moe_lora_expand(
"SPLIT_K": 1, # Set split_k = 1 for expand calls "SPLIT_K": 1, # Set split_k = 1 for expand calls
"USE_GDC": use_gdc, "USE_GDC": use_gdc,
"launch_pdl": use_gdc, # triton kernel metadata "launch_pdl": use_gdc, # triton kernel metadata
"USE_TMA": use_tma,
} }
grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs( grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
...@@ -498,10 +604,27 @@ def _fused_moe_lora_expand( ...@@ -498,10 +604,27 @@ def _fused_moe_lora_expand(
# Fast path: directly accumulate into the corresponding slice interval of output. # Fast path: directly accumulate into the corresponding slice interval of output.
out_view = output[:, :, offset : offset + num_slices * N] out_view = output[:, :, offset : offset + num_slices * N]
slice_c_size = N * out_view.stride(2) slice_c_size = N * out_view.stride(2)
a_desc = None
b_desc = None
if use_tma:
if sorted_token_ids is not None:
a_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
a_intermediate_cache1,
[expand_config["BLOCK_SIZE_M"], expand_config["BLOCK_SIZE_K"]],
)
if num_slices == 1:
b_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
lora_b_stacked[0],
[1, 1, expand_config["BLOCK_SIZE_N"], expand_config["BLOCK_SIZE_K"]],
)
else:
b_desc = None
_fused_moe_lora_kernel[grid]( _fused_moe_lora_kernel[grid](
a_intermediate_cache1, a_intermediate_cache1,
a_desc,
b_ptr, b_ptr,
b_desc,
out_view, out_view,
topk_weights, topk_weights,
sorted_token_ids, sorted_token_ids,
...@@ -535,7 +658,8 @@ def _fused_moe_lora_expand( ...@@ -535,7 +658,8 @@ def _fused_moe_lora_expand(
naive_block_assignment=sorted_token_ids is None, naive_block_assignment=sorted_token_ids is None,
MUL_ROUTED_WEIGHT=mul_routed_weight, MUL_ROUTED_WEIGHT=mul_routed_weight,
ADD_INPUTS=True, ADD_INPUTS=True,
USE_B_L2_CACHE=True, # new USE_B_L2_CACHE=True,
sort_c=False,
IS_PRIMARY=False, IS_PRIMARY=False,
**expand_config, **expand_config,
) )
...@@ -616,8 +740,34 @@ def _fused_moe_lora( ...@@ -616,8 +740,34 @@ def _fused_moe_lora(
else num_tokens * shrink_block_size_m else num_tokens * shrink_block_size_m
) )
# TMA is not currently compatiple with fully_sharded due to the non-determinism
# of token id sorting across ranks.
use_tma = supports_tma(device) and not fully_sharded
intermediate_cache_shape = (
num_slices,
M,
top_k_num,
max_lora_rank,
)
if use_tma:
if num_slices > 1:
# if num_slices > 1, we construct TMA descriptors for LoRA
# weights within the kernel, which requires us to first set an allocator
set_triton_allocator(device)
# When storing intermediate data in sorted order for TMA, we
# need an extra 'num_active_loras' dim in the cache to avoid conflicts
if sorted_token_ids is not None:
intermediate_cache_shape = (
num_slices,
sorted_token_ids.shape[0],
EM,
max_lora_rank,
)
a_intermediate_cache1 = torch.zeros( a_intermediate_cache1 = torch.zeros(
(num_slices, M, top_k_num, max_lora_rank), intermediate_cache_shape,
dtype=output.dtype, dtype=output.dtype,
device=device, device=device,
) )
...@@ -654,6 +804,7 @@ def _fused_moe_lora( ...@@ -654,6 +804,7 @@ def _fused_moe_lora(
num_active_loras, num_active_loras,
mul_routed_weight, mul_routed_weight,
use_gdc=use_gdc, use_gdc=use_gdc,
use_tma=use_tma,
) )
if fully_sharded: if fully_sharded:
...@@ -703,6 +854,7 @@ def _fused_moe_lora( ...@@ -703,6 +854,7 @@ def _fused_moe_lora(
mul_routed_weight, mul_routed_weight,
offset, offset,
use_gdc=use_gdc, use_gdc=use_gdc,
use_tma=use_tma,
) )
...@@ -772,6 +924,7 @@ def _fused_moe_lora_shrink_fake( ...@@ -772,6 +924,7 @@ def _fused_moe_lora_shrink_fake(
num_active_loras: int, num_active_loras: int,
mul_routed_weight: bool = False, mul_routed_weight: bool = False,
use_gdc: bool = False, use_gdc: bool = False,
use_tma: bool = False,
) -> None: ) -> None:
return return
...@@ -809,6 +962,7 @@ def _fused_moe_lora_expand_fake( ...@@ -809,6 +962,7 @@ def _fused_moe_lora_expand_fake(
mul_routed_weight: bool = False, mul_routed_weight: bool = False,
offset: int = 0, offset: int = 0,
use_gdc: bool = False, use_gdc: bool = False,
use_tma: bool = False,
) -> None: ) -> None:
return return
......
...@@ -316,3 +316,9 @@ def supports_pdl(device: torch.device | None = None) -> bool: ...@@ -316,3 +316,9 @@ def supports_pdl(device: torch.device | None = None) -> bool:
and current_platform.has_device_capability(90) and current_platform.has_device_capability(90)
and not envs.VLLM_LORA_DISABLE_PDL and not envs.VLLM_LORA_DISABLE_PDL
) )
@lru_cache
def supports_tma(device: torch.device | None = None) -> bool:
# TMA requires compute capability SM90 or above
return current_platform.is_cuda() and current_platform.has_device_capability(90)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import triton
def set_triton_allocator(device: torch.device):
def alloc_fn(size: int, alignment: int, stream: int | None):
return torch.empty(size, device=device, dtype=torch.int8)
triton.set_allocator(alloc_fn)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment