Unverified Commit 9f698563 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernel] register punica functions as torch ops (#7591)

parent d4f0f17b
...@@ -5,8 +5,6 @@ Punica: Multi-Tenant LoRA Serving. ...@@ -5,8 +5,6 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547 https://arxiv.org/abs/2310.18547
""" """
from typing import Dict, Optional
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
...@@ -86,14 +84,13 @@ def _bgmv_expand_kernel( ...@@ -86,14 +84,13 @@ def _bgmv_expand_kernel(
@torch.inference_mode() @torch.inference_mode()
def bgmv_expand( def _bgmv_expand(
inputs: torch.Tensor, inputs: torch.Tensor,
lora_b_weights: torch.Tensor, lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor,
add_inputs: bool = True, add_inputs: bool = True,
override_config: Optional[Dict[str, int]] = None, ) -> None:
):
""" """
Args: Args:
inputs (torch.Tensor): input tensor inputs (torch.Tensor): input tensor
...@@ -105,10 +102,7 @@ def bgmv_expand( ...@@ -105,10 +102,7 @@ def bgmv_expand(
batches (int): batch size batches (int): batch size
add_inputs (bool, optional): Defaults to False. adds the final lora add_inputs (bool, optional): Defaults to False. adds the final lora
results to the output. results to the output.
override_config (Optional[Dict[str, int]], optional): Defaults to None.
Triton grid config
""" """
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [ assert lora_b_weights.dtype in [
torch.float16, torch.float16,
...@@ -138,10 +132,7 @@ def bgmv_expand( ...@@ -138,10 +132,7 @@ def bgmv_expand(
]: ]:
CAST_TYPE = True CAST_TYPE = True
batches = lora_indices_tensor.size(0) batches = lora_indices_tensor.size(0)
if override_config: config = get_lora_op_configs("expand", batches, N)
config = override_config
else:
config = get_lora_op_configs("expand", batches, N)
grid = lambda META: ( grid = lambda META: (
META["SPLIT_N"], META["SPLIT_N"],
batches, batches,
...@@ -167,3 +158,8 @@ def bgmv_expand( ...@@ -167,3 +158,8 @@ def bgmv_expand(
**config, **config,
) )
return return
bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
_bgmv_expand,
mutates_args=["output_tensor"])
...@@ -5,8 +5,6 @@ Punica: Multi-Tenant LoRA Serving. ...@@ -5,8 +5,6 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547 https://arxiv.org/abs/2310.18547
""" """
from typing import Dict, Optional
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
...@@ -89,7 +87,7 @@ def _bgmv_expand_slice_kernel( ...@@ -89,7 +87,7 @@ def _bgmv_expand_slice_kernel(
@torch.inference_mode() @torch.inference_mode()
def bgmv_expand_slice( def _bgmv_expand_slice(
inputs: torch.Tensor, inputs: torch.Tensor,
lora_b_weights: torch.Tensor, lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
...@@ -97,8 +95,7 @@ def bgmv_expand_slice( ...@@ -97,8 +95,7 @@ def bgmv_expand_slice(
slice_offset: int, slice_offset: int,
slice_size: int, slice_size: int,
add_inputs: bool = True, add_inputs: bool = True,
override_config: Optional[Dict[str, int]] = None, ) -> None:
):
""" """
Args: Args:
inputs (torch.Tensor): input tensor inputs (torch.Tensor): input tensor
...@@ -111,10 +108,7 @@ def bgmv_expand_slice( ...@@ -111,10 +108,7 @@ def bgmv_expand_slice(
slice_size (int): current output_tensor's size slice_size (int): current output_tensor's size
batches (int): batch size batches (int): batch size
add_inputs (bool, optional): Defaults to False. add_inputs (bool, optional): Defaults to False.
override_config (Optional[Dict[str, int]], optional): Defaults to None.
Triton grid config
""" """
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [ assert lora_b_weights.dtype in [
torch.float16, torch.float16,
...@@ -149,10 +143,7 @@ def bgmv_expand_slice( ...@@ -149,10 +143,7 @@ def bgmv_expand_slice(
batches = lora_indices_tensor.size(0) batches = lora_indices_tensor.size(0)
if override_config: config = get_lora_op_configs("expand", batches, N)
config = override_config
else:
config = get_lora_op_configs("expand", batches, N)
grid = lambda META: ( grid = lambda META: (
META["SPLIT_N"], META["SPLIT_N"],
...@@ -180,3 +171,8 @@ def bgmv_expand_slice( ...@@ -180,3 +171,8 @@ def bgmv_expand_slice(
**config, **config,
) )
return return
bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
_bgmv_expand_slice,
mutates_args=["output_tensor"])
...@@ -5,8 +5,6 @@ Punica: Multi-Tenant LoRA Serving. ...@@ -5,8 +5,6 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547 https://arxiv.org/abs/2310.18547
""" """
from typing import Dict, Optional
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
...@@ -78,14 +76,13 @@ def _bgmv_shrink_kernel( ...@@ -78,14 +76,13 @@ def _bgmv_shrink_kernel(
@torch.inference_mode() @torch.inference_mode()
def bgmv_shrink( def _bgmv_shrink(
inputs: torch.Tensor, inputs: torch.Tensor,
lora_a_weights: torch.Tensor, lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor,
scaling: float = 1.0, scaling: float = 1.0,
override_config: Optional[Dict[str, int]] = None, ) -> None:
):
""" """
Args: Args:
inputs (torch.Tensor): input tensor inputs (torch.Tensor): input tensor
...@@ -96,8 +93,6 @@ def bgmv_shrink( ...@@ -96,8 +93,6 @@ def bgmv_shrink(
applied. applied.
batches (int): batch size batches (int): batch size
scaling (float): Scaling factor. scaling (float): Scaling factor.
override_config (Optional[Dict[str, int]], optional): Defaults to None.
Triton grid config
""" """
assert inputs.dtype == lora_a_weights.dtype assert inputs.dtype == lora_a_weights.dtype
assert inputs.dtype in [torch.float16, torch.bfloat16] assert inputs.dtype in [torch.float16, torch.bfloat16]
...@@ -119,11 +114,8 @@ def bgmv_shrink( ...@@ -119,11 +114,8 @@ def bgmv_shrink(
batches = lora_indices_tensor.size(0) batches = lora_indices_tensor.size(0)
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
BLOCK_N = triton.next_power_of_2(N) BLOCK_N = triton.next_power_of_2(N)
if override_config: # First try to load optimal config from the file
config = override_config config = get_lora_op_configs("bgmv_shrink", batches, K)
else:
# First try to load optimal config from the file
config = get_lora_op_configs("bgmv_shrink", batches, K)
grid = lambda META: ( grid = lambda META: (
META["SPLIT_K"], META["SPLIT_K"],
...@@ -148,3 +140,8 @@ def bgmv_shrink( ...@@ -148,3 +140,8 @@ def bgmv_shrink(
**config, **config,
) )
return return
bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
_bgmv_shrink,
mutates_args=["output_tensor"])
...@@ -97,7 +97,7 @@ def _sgmv_expand_kernel( ...@@ -97,7 +97,7 @@ def _sgmv_expand_kernel(
@torch.inference_mode() @torch.inference_mode()
def sgmv_expand( def _sgmv_expand(
inputs: torch.Tensor, inputs: torch.Tensor,
lora_b_weights: torch.Tensor, lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
...@@ -107,7 +107,7 @@ def sgmv_expand( ...@@ -107,7 +107,7 @@ def sgmv_expand(
batches: int, batches: int,
max_seq_length: int, max_seq_length: int,
add_inputs: bool = False, add_inputs: bool = False,
): ) -> None:
""" """
Args: Args:
inputs (torch.Tensor): input tensor inputs (torch.Tensor): input tensor
...@@ -190,3 +190,8 @@ def sgmv_expand( ...@@ -190,3 +190,8 @@ def sgmv_expand(
CAST_TYPE, CAST_TYPE,
) )
return return
sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
_sgmv_expand,
mutates_args=["output_tensor"])
...@@ -103,7 +103,7 @@ def _sgmv_expand_slice_kernel( ...@@ -103,7 +103,7 @@ def _sgmv_expand_slice_kernel(
@torch.inference_mode() @torch.inference_mode()
def sgmv_expand_slice( def _sgmv_expand_slice(
inputs: torch.Tensor, inputs: torch.Tensor,
lora_b_weights: torch.Tensor, lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
...@@ -115,7 +115,7 @@ def sgmv_expand_slice( ...@@ -115,7 +115,7 @@ def sgmv_expand_slice(
slice_offset: int, slice_offset: int,
slice_size: int, slice_size: int,
add_inputs: bool = False, add_inputs: bool = False,
): ) -> None:
"""_summary_ """_summary_
Args: Args:
...@@ -203,3 +203,8 @@ def sgmv_expand_slice( ...@@ -203,3 +203,8 @@ def sgmv_expand_slice(
CAST_TYPE, CAST_TYPE,
) )
return return
sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
_sgmv_expand_slice,
mutates_args=["output_tensor"])
...@@ -101,7 +101,7 @@ def _sgmv_shrink_kernel( ...@@ -101,7 +101,7 @@ def _sgmv_shrink_kernel(
@torch.inference_mode() @torch.inference_mode()
def sgmv_shrink( def _sgmv_shrink(
inputs: torch.Tensor, inputs: torch.Tensor,
lora_a_weights: torch.Tensor, lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
...@@ -111,7 +111,7 @@ def sgmv_shrink( ...@@ -111,7 +111,7 @@ def sgmv_shrink(
batches: int, batches: int,
max_seq_length: int, max_seq_length: int,
scaling: float, scaling: float,
): ) -> None:
""" """
Args: Args:
...@@ -187,3 +187,8 @@ def sgmv_shrink( ...@@ -187,3 +187,8 @@ def sgmv_shrink(
SPLIT_K, SPLIT_K,
) )
return return
sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
_sgmv_shrink,
mutates_args=["output_tensor"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment