Unverified Commit 47043eb6 authored by Tuan, Hoang-Trong's avatar Tuan, Hoang-Trong Committed by GitHub
Browse files

[Kernel] Triton implementation of causal-conv1d for Mamba-based models (#18218)


Signed-off-by: default avatarTuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Co-authored-by: default avatarTuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Co-authored-by: default avatarTyler Michael Smith <tysmith@redhat.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
parent 31b96d1c
...@@ -232,7 +232,6 @@ endif() ...@@ -232,7 +232,6 @@ endif()
set(VLLM_EXT_SRC set(VLLM_EXT_SRC
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu" "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
"csrc/cache_kernels.cu" "csrc/cache_kernels.cu"
"csrc/attention/paged_attention_v1.cu" "csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu" "csrc/attention/paged_attention_v2.cu"
......
This diff is collapsed.
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
struct ConvParamsBase {
using index_t = uint32_t;
int batch, dim, seqlen, width;
int64_t pad_slot_id;
bool silu_activation;
index_t x_batch_stride;
index_t x_c_stride;
index_t x_l_stride;
index_t weight_c_stride;
index_t weight_width_stride;
index_t out_batch_stride;
index_t out_c_stride;
index_t out_l_stride;
int conv_state_len;
index_t conv_state_batch_stride;
index_t conv_state_c_stride;
index_t conv_state_l_stride;
// Common data pointers.
void *__restrict__ x_ptr;
void *__restrict__ weight_ptr;
void *__restrict__ bias_ptr;
void *__restrict__ out_ptr;
void *__restrict__ conv_state_ptr;
void *__restrict__ query_start_loc_ptr;
void *__restrict__ has_initial_state_ptr;
void *__restrict__ cache_indices_ptr;
int32_t *__restrict__ cache_seqlens;
// For the continuous batching case. Makes it so that the mamba state for
// the current batch doesn't need to be a contiguous tensor.
int32_t *__restrict__ conv_state_indices_ptr;
void *__restrict__ seq_idx_ptr;
// No __restrict__ since initial_states could be the same as final_states.
void * initial_states_ptr;
index_t initial_states_batch_stride;
index_t initial_states_l_stride;
index_t initial_states_c_stride;
void * final_states_ptr;
index_t final_states_batch_stride;
index_t final_states_l_stride;
index_t final_states_c_stride;
void * conv_states_ptr;
index_t conv_states_batch_stride;
index_t conv_states_l_stride;
index_t conv_states_c_stride;
};
#ifndef USE_ROCM
#include <cuda_bf16.h>
template<typename T>
__device__ inline T shuffle_xor(T val, int offset) {
return __shfl_xor_sync(uint32_t(-1), val, offset);
}
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
return std::max(ilist);
}
template<typename T>
constexpr T constexpr_min(T a, T b) {
return std::min(a, b);
}
#else
#include <hip/hip_bf16.h>
template<typename T>
__device__ inline T shuffle_xor(T val, int offset) {
return __shfl_xor(val, offset);
}
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
return *std::max_element(ilist.begin(), ilist.end());
}
template<typename T>
constexpr T constexpr_min(T a, T b) {
return a < b ? a : b;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int BYTES> struct BytesToType {};
template<> struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<> struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<> struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<> struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<> struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
};
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
static constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
static constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
...@@ -326,22 +326,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, ...@@ -326,22 +326,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const std::optional<torch::Tensor>& has_initial_state, const std::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states, int64_t pad_slot_id); const torch::Tensor& ssm_states, int64_t pad_slot_id);
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
const at::Tensor& weight,
const std::optional<at::Tensor>& bias_,
bool silu_activation,
const std::optional<at::Tensor>& cache_seqlens_,
const std::optional<at::Tensor>& conv_state_indices_,
int64_t pad_slot_id);
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const std::optional<at::Tensor>& bias_,
const std::optional<at::Tensor>& conv_states,
const std::optional<at::Tensor>& query_start_loc,
const std::optional<at::Tensor>& cache_indices,
const std::optional<at::Tensor>& has_initial_state,
bool silu_activation, int64_t pad_slot_id);
using fptr_t = int64_t; using fptr_t = int64_t;
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs, fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank, torch::Tensor& rank_data, int64_t rank,
......
...@@ -594,28 +594,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -594,28 +594,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int pad_slot_id) -> ()"); "int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
ops.def(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
ops.def(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor!? conv_states,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#ifndef USE_ROCM #ifndef USE_ROCM
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops.def( ops.def(
......
...@@ -6,9 +6,8 @@ from typing import Optional ...@@ -6,9 +6,8 @@ from typing import Optional
import pytest import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update) causal_conv1d_fn, causal_conv1d_update)
...@@ -144,79 +143,6 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor, ...@@ -144,79 +143,6 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
x = x.contiguous() x = x.contiguous()
bias = bias.contiguous() if bias is not None else None bias = bias.contiguous() if bias is not None else None
opcheck(torch.ops._C.causal_conv1d_fwd,
(x, weight, bias, conv_states, cu_seq_len, cache_indices,
has_initial_state, activation in ["silu", "swish"], pad_slot_id))
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("has_initial_state", [True, False])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
@pytest.mark.parametrize('dim', [64])
@pytest.mark.parametrize('batch', [1])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
has_initial_state, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
current_platform.seed_everything(0)
x = torch.randn(batch, dim, seqlen, device=device,
dtype=itype).contiguous()
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
if has_initial_state:
initial_states = torch.randn(batch,
dim,
width - 1,
device=device,
dtype=itype)
has_initial_state_tensor = torch.ones(batch,
dtype=torch.bool,
device=x.device)
else:
initial_states = None
has_initial_state_tensor = None
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
initial_states_ref = initial_states.clone(
) if initial_states is not None else None
activation = None if not silu_activation else "silu"
out = causal_conv1d_fn(x,
weight,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=has_initial_state_tensor)
out_ref, final_states_ref = causal_conv1d_ref(
x_ref,
weight_ref,
bias_ref,
initial_states=initial_states_ref,
return_final_states=True,
activation=activation)
if has_initial_state:
assert initial_states is not None and final_states_ref is not None
assert torch.allclose(initial_states,
final_states_ref,
rtol=rtol,
atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
causal_conv1d_opcheck_fn(x,
weight,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=has_initial_state_tensor)
@pytest.mark.parametrize("itype", [torch.bfloat16]) @pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True]) @pytest.mark.parametrize("silu_activation", [False, True])
...@@ -255,22 +181,19 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ...@@ -255,22 +181,19 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
assert torch.equal(conv_state, conv_state_ref) assert torch.equal(conv_state, conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
opcheck(torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation
in ["silu", "swish"], None, None, PAD_SLOT_ID))
@pytest.mark.parametrize("itype", @pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16]) [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True]) @pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True]) @pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 4, 5]) @pytest.mark.parametrize("seqlen", [1, 3])
@pytest.mark.parametrize("width", [2, 3, 4]) @pytest.mark.parametrize("width", [3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) @pytest.mark.parametrize("dim", [2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded # tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False]) @pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, @pytest.mark.parametrize("batch_size", [3])
seqlen, has_bias, def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
width, seqlen, has_bias,
silu_activation, itype): silu_activation, itype):
device = "cuda" device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
...@@ -280,12 +203,15 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, ...@@ -280,12 +203,15 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
# set seed # set seed
current_platform.seed_everything(0) current_platform.seed_everything(0)
batch_size = 3
padding = 5 if with_padding else 0 padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding padded_batch_size = batch_size + padding
# total_entries = number of cache line
total_entries = 10 * batch_size total_entries = 10 * batch_size
x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype) # x will be (batch, dim, seqlen) with contiguous along dim-axis
x = torch.randn(padded_batch_size, seqlen, dim, device=device,
dtype=itype).transpose(1, 2)
x_ref = x.clone() x_ref = x.clone()
conv_state_indices = torch.randperm(total_entries)[:batch_size].to( conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
...@@ -300,17 +226,22 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, ...@@ -300,17 +226,22 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
], ],
dim=0) dim=0)
# conv_state will be (cache_lines, dim, state_len)
# with contiguous along dim-axis
conv_state = torch.randn(total_entries, conv_state = torch.randn(total_entries,
dim,
width - 1, width - 1,
dim,
device=device, device=device,
dtype=itype) dtype=itype).transpose(1, 2)
conv_state_for_padding_test = conv_state.clone() conv_state_for_padding_test = conv_state.clone()
weight = torch.randn(dim, width, device=device, dtype=itype) weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state[conv_state_indices, :].detach().clone() conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
activation = None if not silu_activation else "silu" activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x, out = causal_conv1d_update(x,
conv_state, conv_state,
weight, weight,
...@@ -325,26 +256,21 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, ...@@ -325,26 +256,21 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
activation=activation) activation=activation)
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
assert torch.equal(conv_state[unused_states_bool], assert torch.equal(conv_state[unused_states_bool],
conv_state_for_padding_test[unused_states_bool]) conv_state_for_padding_test[unused_states_bool])
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
opcheck(torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation
in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID))
@pytest.mark.parametrize("itype", [torch.bfloat16]) @pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4]) @pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize( @pytest.mark.parametrize('seqlen', [8, 30, 249, 2049, 4096])
'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
@pytest.mark.parametrize('dim', [64, 4096]) @pytest.mark.parametrize('dim', [64, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize('with_padding', [True, False]) @pytest.mark.parametrize('with_padding', [True, False])
def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, @pytest.mark.parametrize('batch', [4, 10])
silu_activation, itype): def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width,
has_bias, silu_activation, itype):
device = "cuda" device = "cuda"
torch.cuda.empty_cache() torch.cuda.empty_cache()
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
...@@ -353,14 +279,13 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, ...@@ -353,14 +279,13 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
# set seed # set seed
current_platform.seed_everything(0) current_platform.seed_everything(0)
seqlens = [] seqlens = []
batch_size = 4 batch_size = batch
if seqlen < 10:
batch_size = 1
padding = 3 if with_padding else 0 padding = 3 if with_padding else 0
padded_batch_size = batch_size + padding padded_batch_size = batch_size + padding
nsplits = padded_batch_size - 1 nsplits = padded_batch_size - 1
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
seqlens.append( seqlens.append(
torch.diff( torch.diff(
torch.cat( torch.cat(
...@@ -373,19 +298,22 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, ...@@ -373,19 +298,22 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0) dim=0)
x = torch.randn(1, 4096 + dim + 64, seqlen, device=device, x = rearrange(
dtype=itype)[:, 4096:4096 + dim, :] torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype),
"b s d -> b d s")[:, 4096:4096 + dim, :]
weight = torch.randn(dim, width, device=device, dtype=itype) weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
x_ref = x.clone() x_ref = x.clone()
weight_ref = weight.clone() weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None bias_ref = bias.clone() if bias is not None else None
activation = None if not silu_activation else "silu" activation = None if not silu_activation else "silu"
final_states = torch.randn(total_entries, final_states = torch.randn(total_entries,
dim,
width - 1, width - 1,
dim,
device=x.device, device=x.device,
dtype=x.dtype) dtype=x.dtype).transpose(1, 2)
final_states_ref = final_states.clone() final_states_ref = final_states.clone()
has_initial_states = torch.randint(0, has_initial_states = torch.randint(0,
2, (cumsum.shape[0] - 1, ), 2, (cumsum.shape[0] - 1, ),
...@@ -400,10 +328,16 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, ...@@ -400,10 +328,16 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
], ],
dim=-1) dim=-1)
out = causal_conv1d_fn(x.squeeze(0),
weight,
bias=bias,
conv_states=final_states,
query_start_loc=cumsum.cuda(),
cache_indices=padded_state_indices,
has_initial_state=has_initial_states,
activation=activation,
pad_slot_id=PAD_SLOT_ID)
out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
padded_state_indices, has_initial_states,
final_states, activation, PAD_SLOT_ID)
out_ref = [] out_ref = []
out_ref_b = [] out_ref_b = []
...@@ -426,13 +360,9 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, ...@@ -426,13 +360,9 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
out_ref_tensor = torch.cat(out_ref, dim=0) out_ref_tensor = torch.cat(out_ref, dim=0)
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
assert torch.allclose(final_states[state_indices], assert torch.allclose(final_states[state_indices],
final_states_ref[state_indices], final_states_ref[state_indices],
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(), assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
padded_state_indices, has_initial_states,
final_states, activation)
...@@ -6,11 +6,11 @@ import torch ...@@ -6,11 +6,11 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.mamba2_metadata import (
_query_start_loc_to_chunk_indices_offsets)
from vllm.model_executor.layers.mamba.ops.ssd_combined import ( from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined) mamba_chunk_scan_combined)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba_attn import (
_query_start_loc_to_chunk_indices_offsets)
# Added by the IBM Team, 2024 # Added by the IBM Team, 2024
......
...@@ -963,17 +963,17 @@ def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, ...@@ -963,17 +963,17 @@ def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
expert_offsets: torch.Tensor, sf_offsets: torch.Tensor, expert_offsets: torch.Tensor, sf_offsets: torch.Tensor,
out_dtype: torch.dtype, device: torch.device): out_dtype: torch.dtype, device: torch.device):
""" """
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
the gemms for each combination based on the specified problem sizes. the gemms for each combination based on the specified problem sizes.
This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward. This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward.
- a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized - a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized
input and expert weights. input and expert weights.
- a_/b_scales: The blockscales in FP8-E4M3 precision - a_/b_scales: The blockscales in FP8-E4M3 precision
- expert_offsets/sf_offsets: Indices that mark at which token index - expert_offsets/sf_offsets: Indices that mark at which token index
each expert begins its computation. The number of tokens each expert begins its computation. The number of tokens
computed with expert E is expert_offsets[E + 1] - computed with expert E is expert_offsets[E + 1] -
expert_offsets[E] And the sf_size per expert is expert_offsets[E] And the sf_size per expert is
sf_offset[E+1] - sf_offset[E] sf_offset[E+1] - sf_offset[E]
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
MMs used in the fused MoE operation. MMs used in the fused MoE operation.
...@@ -1464,30 +1464,6 @@ def ggml_moe_get_block_size(quant_type: int) -> int: ...@@ -1464,30 +1464,6 @@ def ggml_moe_get_block_size(quant_type: int) -> int:
# mamba # mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
conv_states: Optional[torch.Tensor],
query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
silu_activation: bool, pad_slot_id: int):
torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
query_start_loc, cache_indices,
has_initial_state, silu_activation,
pad_slot_id)
def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor, bias_: Optional[torch.Tensor],
silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor],
pad_slot_id: int):
torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
silu_activation, cache_seqlens,
conv_state_indices, pad_slot_id)
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, C: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.placeholder_attn import ( from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata) PlaceholderAttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba_attn import (
Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
@dataclass @dataclass
...@@ -21,6 +25,29 @@ class Mamba2Metadata: ...@@ -21,6 +25,29 @@ class Mamba2Metadata:
seq_idx: torch.Tensor seq_idx: torch.Tensor
chunk_indices: torch.Tensor chunk_indices: torch.Tensor
chunk_offsets: torch.Tensor chunk_offsets: torch.Tensor
"""
With continuous batching layout of `x` in vLLM, to enable a Triton program
to handle a request in parallel, two supporting tensors are used
(batch_ptr, token_chunk_offset_ptr)
BLOCK_M = the # tokens to be handled by a Triton program
(can be customized for different hardware)
nums_dict:
tracks the data associated with a given value of BLOCK_M
BLOCK_M = #tokens handled by a Triton program
cu_seqlen: total tokens per batch
(used as flag to update other data at each new input)
batch_ptr: tracks batch-id handled by the Triton program
token_chunk_offset_ptr: tracks token group_idx handled by the Triton program
(Triton implementation of causal_conv1d handles parallelism in 3-axes
- feature-axis
- batch-axis
- sequence-axis)
"""
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.tensor] = None
token_chunk_offset_ptr: Optional[torch.tensor] = None
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
...@@ -38,45 +65,10 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: ...@@ -38,45 +65,10 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
f"Unsupported platform for Mamba2: {current_platform.device_type}") f"Unsupported platform for Mamba2: {current_platform.device_type}")
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
total_seqlens: int):
cu_seqlens = query_start_loc[1:] # remove prepended 0
# outputs will have length expansion of chunks that do not divide
# chunk_size
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
> 0).sum()
chunk_indices = torch.arange(N,
dtype=torch.int,
device=query_start_loc.device)
chunk_offsets = torch.zeros((N, ),
dtype=torch.int,
device=query_start_loc.device)
p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
# if does not divide chunk_size, then there is one chunk insertion
p += (s % chunk_size > 0)
# get the dimensions
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
> 0)
# adjust inidces and offsets
chunk_indices[_s:_e] -= p
chunk_offsets[_s] = s % chunk_size
return chunk_indices, chunk_offsets
def prepare_mamba2_metadata( def prepare_mamba2_metadata(
chunk_size: int, chunk_size: int,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
mamba2_metadata=None,
) -> Mamba2Metadata: ) -> Mamba2Metadata:
# compute number of prefill and decode requests # compute number of prefill and decode requests
...@@ -96,12 +88,12 @@ def prepare_mamba2_metadata( ...@@ -96,12 +88,12 @@ def prepare_mamba2_metadata(
attn_metadata_instances = get_platform_metadata_classes() attn_metadata_instances = get_platform_metadata_classes()
if (isinstance(attn_metadata, attn_metadata_instances) if (isinstance(attn_metadata, attn_metadata_instances)
and attn_metadata.context_lens_tensor is not None): and attn_metadata.context_lens_tensor is not None):
has_initial_states = \ # precompute flag to avoid device syncs later in mamba2 layer
attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,] # forwards
# precompute flag to avoid device syncs in mamba2 layer forwards
# prep is only needed for mamba2 ssd prefill processing # prep is only needed for mamba2 ssd prefill processing
prep_initial_states = torch.any(has_initial_states).item() has_initial_states = attn_metadata.context_lens_tensor > 0
prep_initial_states = torch.any(
has_initial_states[:num_prefills]).item()
query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1] query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
seq_idx = torch.repeat_interleave(torch.arange( seq_idx = torch.repeat_interleave(torch.arange(
num_prefills, dtype=torch.int32, device=query_start_loc.device), num_prefills, dtype=torch.int32, device=query_start_loc.device),
...@@ -117,9 +109,78 @@ def prepare_mamba2_metadata( ...@@ -117,9 +109,78 @@ def prepare_mamba2_metadata(
_query_start_loc_to_chunk_indices_offsets( _query_start_loc_to_chunk_indices_offsets(
query_start_loc, chunk_size, num_prefill_tokens) query_start_loc, chunk_size, num_prefill_tokens)
if mamba2_metadata is not None:
mamba2_metadata.has_initial_states = has_initial_states
mamba2_metadata.prep_initial_states = prep_initial_states
mamba2_metadata.chunk_size = chunk_size
mamba2_metadata.seq_idx = seq_idx
mamba2_metadata.chunk_indices = chunk_indices
mamba2_metadata.chunk_offsets = chunk_offsets
# We use 1 reset flag:
# * mamba2_metadata.cu_seqlen is None
# update config specific to (each input)
# (become available at first layer, e.g. conv_weights)
mamba2_metadata.cu_seqlen = None # suppose to be updated at each input
return mamba2_metadata
return Mamba2Metadata(has_initial_states=has_initial_states, return Mamba2Metadata(has_initial_states=has_initial_states,
prep_initial_states=prep_initial_states, prep_initial_states=prep_initial_states,
chunk_size=chunk_size, chunk_size=chunk_size,
seq_idx=seq_idx, seq_idx=seq_idx,
chunk_indices=chunk_indices, chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets) chunk_offsets=chunk_offsets)
def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
mamba2_metadata: Union[Mamba2Metadata,
Mamba2AttentionMetadata]):
"""
this is triggered upon handling a new input at the first layer
"""
dim, cu_seqlen = x.shape
mamba2_metadata.cu_seqlen = cu_seqlen
seqlens = np.diff(query_start_loc.to('cpu'))
nums_dict = {} # type: ignore
for BLOCK_M in [8]: # cover all BLOCK_M values
nums = -(-seqlens // BLOCK_M)
nums_dict[BLOCK_M] = {}
nums_dict[BLOCK_M]['nums'] = nums
nums_dict[BLOCK_M]['tot'] = nums.sum().item()
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
nums_dict[BLOCK_M]['mlist'] = mlist
mlist_len = len(nums_dict[BLOCK_M]['mlist'])
nums_dict[BLOCK_M]['mlist_len'] = mlist_len
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
offsetlist = [] # type: ignore
for idx, num in enumerate(nums):
offsetlist.extend(range(num))
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
nums_dict[BLOCK_M]['offsetlist'] = offsetlist
if mamba2_metadata.batch_ptr is None:
# Update default value after class definition
#mamba2_metadata.MAX_NUM_PROGRAMS *= 2
mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
mamba2_metadata.token_chunk_offset_ptr = torch.full(
(MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
else:
if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS:
mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(
PAD_SLOT_ID)
mamba2_metadata.token_chunk_offset_ptr.resize_( # type: ignore
MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist)
mamba2_metadata.token_chunk_offset_ptr[ # type: ignore
0:mlist_len].copy_(offsetlist)
nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr
nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (
mamba2_metadata.token_chunk_offset_ptr) # type: ignore
mamba2_metadata.nums_dict = nums_dict
return mamba2_metadata
...@@ -159,7 +159,7 @@ class MambaMixer(CustomOp): ...@@ -159,7 +159,7 @@ class MambaMixer(CustomOp):
hidden_states = causal_conv1d_fn( hidden_states = causal_conv1d_fn(
hidden_states, hidden_states,
conv_weights, conv_weights,
self.conv1d.bias, bias=self.conv1d.bias,
activation=self.activation, activation=self.activation,
conv_states=mamba_cache_params.conv_state, conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0, has_initial_state=attn_metadata.context_lens_tensor > 0,
......
...@@ -17,7 +17,8 @@ from vllm.forward_context import get_forward_context ...@@ -17,7 +17,8 @@ from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
update_metadata)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update) causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
...@@ -161,9 +162,9 @@ def mamba_v2_sharded_weight_loader( ...@@ -161,9 +162,9 @@ def mamba_v2_sharded_weight_loader(
tp_size: int, tp_size: int,
tp_rank: int, tp_rank: int,
) -> LoaderFunction: ) -> LoaderFunction:
"""Create a weight loader for mamba v2. This ensures that the projections """Create a weight loader for mamba v2. This ensures that the projections
are correctly sharded so that they can be split into x, B, C. It also are correctly sharded so that they can be split into x, B, C. It also
ensures that all the groups corresponding to a head shard is placed ensures that all the groups corresponding to a head shard is placed
together with it. together with it.
""" """
...@@ -458,9 +459,11 @@ class MambaMixer2(CustomOp): ...@@ -458,9 +459,11 @@ class MambaMixer2(CustomOp):
if attn_metadata is not None: if attn_metadata is not None:
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
mamba2_metadata = attn_metadata
assert isinstance(attn_metadata, Mamba2AttentionMetadata) assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0] # conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states has_initial_states_p = attn_metadata.has_initial_states
...@@ -531,6 +534,7 @@ class MambaMixer2(CustomOp): ...@@ -531,6 +534,7 @@ class MambaMixer2(CustomOp):
# NOTE: V0 put prefill before decode, v1 puts decode before prefill # NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input # Separate prefill and decode by splitting varlen input
# Split along token dimension # Split along token dimension
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
hidden_states_B_C_d, hidden_states_B_C_p = torch.split( hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
hidden_states_B_C, hidden_states_B_C,
...@@ -579,8 +583,13 @@ class MambaMixer2(CustomOp): ...@@ -579,8 +583,13 @@ class MambaMixer2(CustomOp):
# 2. Convolution sequence transformation # 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions # - "cache_indices" updates the conv_state cache in positions
# pointed to by "state_indices_tensor" # pointed to by "state_indices_tensor"
x = hidden_states_B_C_p.transpose(
0, 1) # this is the form that causal-conv see
if mamba2_metadata.cu_seqlen is None:
mamba2_metadata = update_metadata(
x, attn_metadata.query_start_loc, mamba2_metadata)
hidden_states_B_C_p = causal_conv1d_fn( hidden_states_B_C_p = causal_conv1d_fn(
hidden_states_B_C_p.transpose(0, 1), x,
conv_weights, conv_weights,
self.conv1d.bias, self.conv1d.bias,
activation=self.activation, activation=self.activation,
...@@ -590,8 +599,6 @@ class MambaMixer2(CustomOp): ...@@ -590,8 +599,6 @@ class MambaMixer2(CustomOp):
query_start_loc=query_start_loc_p).transpose( query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens] 0, 1)[:num_prefill_tokens]
# TODO: Why is this needed?
hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn( hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
hidden_states_B_C_p) hidden_states_B_C_p)
...@@ -715,9 +722,10 @@ class MambaMixer2(CustomOp): ...@@ -715,9 +722,10 @@ class MambaMixer2(CustomOp):
# - heads and n_groups are TP-ed # - heads and n_groups are TP-ed
conv_dim = (self.intermediate_size + conv_dim = (self.intermediate_size +
2 * n_groups * self.ssm_state_size) 2 * n_groups * self.ssm_state_size)
# contiguous along 'dim' axis
conv_state_shape = ( conv_state_shape = (
divide(conv_dim, world_size),
self.conv_kernel_size - 1, self.conv_kernel_size - 1,
divide(conv_dim, world_size),
) )
# These are not TP-ed as they depend on A, dt_bias, D # These are not TP-ed as they depend on A, dt_bias, D
......
...@@ -36,10 +36,12 @@ class MambaCacheManager(ConstantSizeCache): ...@@ -36,10 +36,12 @@ class MambaCacheManager(ConstantSizeCache):
# Initialize parent class # Initialize parent class
super().__init__(max_batch_size) super().__init__(max_batch_size)
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
conv_state_shape, (conv_state_shape[1], conv_state_shape[0]),
dtype=dtype, dtype=dtype,
device="cuda") device="cuda").transpose(-1, -2)
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) + temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
temporal_state_shape, temporal_state_shape,
dtype=dtype, dtype=dtype,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.mamba.mamba2_metadata import (
_query_start_loc_to_chunk_indices_offsets)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import MambaSpec from vllm.v1.kv_cache_interface import MambaSpec
...@@ -29,6 +28,42 @@ def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int: ...@@ -29,6 +28,42 @@ def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
return chunk_sizes.pop() return chunk_sizes.pop()
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
total_seqlens: int):
cu_seqlens = query_start_loc[1:] # remove prepended 0
# outputs will have length expansion of chunks that do not divide
# chunk_size
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
> 0).sum()
chunk_indices = torch.arange(N,
dtype=torch.int,
device=query_start_loc.device)
chunk_offsets = torch.zeros((N, ),
dtype=torch.int,
device=query_start_loc.device)
p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
# if does not divide chunk_size, then there is one chunk insertion
p += (s % chunk_size > 0)
# get the dimensions
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
> 0)
# adjust indices and offsets
chunk_indices[_s:_e] -= p
chunk_offsets[_s] = s % chunk_size
return chunk_indices, chunk_offsets
class Mamba2AttentionBackend(AttentionBackend): class Mamba2AttentionBackend(AttentionBackend):
@staticmethod @staticmethod
...@@ -53,6 +88,10 @@ class Mamba2AttentionMetadata: ...@@ -53,6 +88,10 @@ class Mamba2AttentionMetadata:
chunk_offsets: torch.Tensor chunk_offsets: torch.Tensor
state_indices_tensor: torch.Tensor # shape: [batch,] state_indices_tensor: torch.Tensor # shape: [batch,]
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.tensor] = None
token_chunk_offset_ptr: Optional[torch.tensor] = None
class Mamba2AttentionMetadataBuilder( class Mamba2AttentionMetadataBuilder(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment