Commit 2163586e (unverified), authored by JieXin Liang, committed by GitHub
Browse files

[feat] triton kernel for get_last_loc (#6676)

parent e06b0761
@@ -1810,10 +1810,72 @@ def write_req_to_token_pool_triton(
) )
def get_last_loc(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Look up the last prefix slot per request via the configured backend.

    The pure-PyTorch implementation is used only when
    ``global_server_args_dict["attention_backend"]`` equals
    ``"torch_native"``; every other backend value selects the Triton
    implementation. All three arguments are forwarded unchanged.

    Args:
        req_to_token: Table indexed as ``[request_slot, position]``.
        req_pool_indices_tensor: Row of ``req_to_token`` for each request.
        prefix_lens_tensor: Prefix length for each request.

    Returns:
        Whatever the selected implementation returns for these inputs.
    """
    use_torch = global_server_args_dict["attention_backend"] == "torch_native"
    backend_fn = get_last_loc_torch if use_torch else get_last_loc_triton
    return backend_fn(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Return ``req_to_token[row, prefix_len - 1]`` per request, or -1.

    For each position ``i`` the result is
    ``req_to_token[req_pool_indices_tensor[i], prefix_lens_tensor[i] - 1]``
    when ``prefix_lens_tensor[i] > 0`` and ``-1`` otherwise.

    Args:
        req_to_token: 2-D table indexed as ``[request_slot, position]``.
        req_pool_indices_tensor: Row index into ``req_to_token`` per request.
        prefix_lens_tensor: Prefix length per request.

    Returns:
        A tensor with one entry per request, as described above.
    """
    has_prefix = prefix_lens_tensor > 0
    # The gather runs for every request; a zero-length prefix indexes
    # column -1 (the last column), and that value is discarded below.
    last_slots = req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1]
    no_prefix_marker = torch.full_like(prefix_lens_tensor, -1)
    return torch.where(has_prefix, last_slots, no_prefix_marker)
@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """Write ``req_to_token[row, prefix_len - 1]`` (or -1) for each entry.

    Each program instance handles ``BLOCK_SIZE`` consecutive entries of
    ``prefix_lens_tensor`` / ``req_pool_indices_tensor`` and stores one value
    per entry into ``result``. Entries whose prefix length is not positive
    receive -1. ``req_to_token_stride`` is the element distance between
    consecutive rows of ``req_to_token``.
    """
    block_id = tl.program_id(0)
    lanes = tl.arange(0, BLOCK_SIZE) + block_id * BLOCK_SIZE
    in_bounds = lanes < num_tokens

    # Out-of-range lanes read a prefix length of 0, so they are also
    # excluded from the gather below.
    lens = tl.load(prefix_lens_tensor + lanes, mask=in_bounds, other=0)
    rows = tl.load(req_pool_indices_tensor + lanes, mask=in_bounds, other=0)

    has_prefix = lens > 0
    flat_index = rows * req_to_token_stride + (lens - 1)
    last_slots = tl.load(req_to_token + flat_index, mask=has_prefix, other=-1)

    tl.store(result + lanes, last_slots, mask=in_bounds)
def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Launch ``get_last_loc_kernel`` and return its output tensor.

    Args:
        req_to_token: Table indexed as ``[request_slot, position]``; its
            ``stride(0)`` is passed to the kernel as the row stride.
        req_pool_indices_tensor: Row index into ``req_to_token`` per request.
        prefix_lens_tensor: Prefix length per request; its first dimension
            sets how many entries are processed.

    Returns:
        A new tensor allocated with ``torch.empty_like(prefix_lens_tensor)``
        and filled by the kernel.
    """
    BLOCK_SIZE = 256
    num_tokens = prefix_lens_tensor.shape[0]
    out = torch.empty_like(prefix_lens_tensor)

    # One program per BLOCK_SIZE entries, rounded up.
    launch_grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)
    get_last_loc_kernel[launch_grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        out,
        num_tokens,
        req_to_token.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment