Unverified Commit 47043eb6 authored by Tuan, Hoang-Trong's avatar Tuan, Hoang-Trong Committed by GitHub
Browse files

[Kernel] Triton implementation of causal-conv1d for Mamba-based models (#18218)


Signed-off-by: default avatarTuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Co-authored-by: default avatarTuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Co-authored-by: default avatarTyler Michael Smith <tysmith@redhat.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
parent 31b96d1c
......@@ -232,7 +232,6 @@ endif()
set(VLLM_EXT_SRC
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
"csrc/cache_kernels.cu"
"csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu"
......
This diff is collapsed.
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
struct ConvParamsBase {
using index_t = uint32_t;
int batch, dim, seqlen, width;
int64_t pad_slot_id;
bool silu_activation;
index_t x_batch_stride;
index_t x_c_stride;
index_t x_l_stride;
index_t weight_c_stride;
index_t weight_width_stride;
index_t out_batch_stride;
index_t out_c_stride;
index_t out_l_stride;
int conv_state_len;
index_t conv_state_batch_stride;
index_t conv_state_c_stride;
index_t conv_state_l_stride;
// Common data pointers.
void *__restrict__ x_ptr;
void *__restrict__ weight_ptr;
void *__restrict__ bias_ptr;
void *__restrict__ out_ptr;
void *__restrict__ conv_state_ptr;
void *__restrict__ query_start_loc_ptr;
void *__restrict__ has_initial_state_ptr;
void *__restrict__ cache_indices_ptr;
int32_t *__restrict__ cache_seqlens;
// For the continuous batching case. Makes it so that the mamba state for
// the current batch doesn't need to be a contiguous tensor.
int32_t *__restrict__ conv_state_indices_ptr;
void *__restrict__ seq_idx_ptr;
// No __restrict__ since initial_states could be the same as final_states.
void * initial_states_ptr;
index_t initial_states_batch_stride;
index_t initial_states_l_stride;
index_t initial_states_c_stride;
void * final_states_ptr;
index_t final_states_batch_stride;
index_t final_states_l_stride;
index_t final_states_c_stride;
void * conv_states_ptr;
index_t conv_states_batch_stride;
index_t conv_states_l_stride;
index_t conv_states_c_stride;
};
#ifndef USE_ROCM
#include <cuda_bf16.h>
template<typename T>
__device__ inline T shuffle_xor(T val, int offset) {
return __shfl_xor_sync(uint32_t(-1), val, offset);
}
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
return std::max(ilist);
}
template<typename T>
constexpr T constexpr_min(T a, T b) {
return std::min(a, b);
}
#else
#include <hip/hip_bf16.h>
template<typename T>
__device__ inline T shuffle_xor(T val, int offset) {
return __shfl_xor(val, offset);
}
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
return *std::max_element(ilist.begin(), ilist.end());
}
template<typename T>
constexpr T constexpr_min(T a, T b) {
return a < b ? a : b;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int BYTES> struct BytesToType {};
template<> struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<> struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<> struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<> struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<> struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
};
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
static constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
static constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
......@@ -326,22 +326,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const std::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states, int64_t pad_slot_id);
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
const at::Tensor& weight,
const std::optional<at::Tensor>& bias_,
bool silu_activation,
const std::optional<at::Tensor>& cache_seqlens_,
const std::optional<at::Tensor>& conv_state_indices_,
int64_t pad_slot_id);
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const std::optional<at::Tensor>& bias_,
const std::optional<at::Tensor>& conv_states,
const std::optional<at::Tensor>& query_start_loc,
const std::optional<at::Tensor>& cache_indices,
const std::optional<at::Tensor>& has_initial_state,
bool silu_activation, int64_t pad_slot_id);
using fptr_t = int64_t;
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank,
......
......@@ -594,28 +594,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
ops.def(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
ops.def(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor!? conv_states,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#ifndef USE_ROCM
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops.def(
......
......@@ -6,9 +6,8 @@ from typing import Optional
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
......@@ -144,79 +143,6 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
opcheck(torch.ops._C.causal_conv1d_fwd,
(x, weight, bias, conv_states, cu_seq_len, cache_indices,
has_initial_state, activation in ["silu", "swish"], pad_slot_id))
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("has_initial_state", [True, False])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
@pytest.mark.parametrize('dim', [64])
@pytest.mark.parametrize('batch', [1])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
has_initial_state, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
current_platform.seed_everything(0)
x = torch.randn(batch, dim, seqlen, device=device,
dtype=itype).contiguous()
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
if has_initial_state:
initial_states = torch.randn(batch,
dim,
width - 1,
device=device,
dtype=itype)
has_initial_state_tensor = torch.ones(batch,
dtype=torch.bool,
device=x.device)
else:
initial_states = None
has_initial_state_tensor = None
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
initial_states_ref = initial_states.clone(
) if initial_states is not None else None
activation = None if not silu_activation else "silu"
out = causal_conv1d_fn(x,
weight,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=has_initial_state_tensor)
out_ref, final_states_ref = causal_conv1d_ref(
x_ref,
weight_ref,
bias_ref,
initial_states=initial_states_ref,
return_final_states=True,
activation=activation)
if has_initial_state:
assert initial_states is not None and final_states_ref is not None
assert torch.allclose(initial_states,
final_states_ref,
rtol=rtol,
atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
causal_conv1d_opcheck_fn(x,
weight,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=has_initial_state_tensor)
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
......@@ -255,22 +181,19 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
assert torch.equal(conv_state, conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
opcheck(torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation
in ["silu", "swish"], None, None, PAD_SLOT_ID))
@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 4, 5])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
@pytest.mark.parametrize("seqlen", [1, 3])
@pytest.mark.parametrize("width", [3, 4])
@pytest.mark.parametrize("dim", [2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
seqlen, has_bias,
@pytest.mark.parametrize("batch_size", [3])
def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
width, seqlen, has_bias,
silu_activation, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
......@@ -280,12 +203,15 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
# set seed
current_platform.seed_everything(0)
batch_size = 3
padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding
# total_entries = number of cache line
total_entries = 10 * batch_size
x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype)
# x will be (batch, dim, seqlen) with contiguous along dim-axis
x = torch.randn(padded_batch_size, seqlen, dim, device=device,
dtype=itype).transpose(1, 2)
x_ref = x.clone()
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
......@@ -300,17 +226,22 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
],
dim=0)
# conv_state will be (cache_lines, dim, state_len)
# with contiguous along dim-axis
conv_state = torch.randn(total_entries,
dim,
width - 1,
dim,
device=device,
dtype=itype)
dtype=itype).transpose(1, 2)
conv_state_for_padding_test = conv_state.clone()
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x,
conv_state,
weight,
......@@ -325,26 +256,21 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
activation=activation)
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
assert torch.equal(conv_state[unused_states_bool],
conv_state_for_padding_test[unused_states_bool])
opcheck(torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation
in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID))
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
@pytest.mark.parametrize('seqlen', [8, 30, 249, 2049, 4096])
@pytest.mark.parametrize('dim', [64, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize('with_padding', [True, False])
def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
silu_activation, itype):
@pytest.mark.parametrize('batch', [4, 10])
def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width,
has_bias, silu_activation, itype):
device = "cuda"
torch.cuda.empty_cache()
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
......@@ -353,14 +279,13 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
# set seed
current_platform.seed_everything(0)
seqlens = []
batch_size = 4
if seqlen < 10:
batch_size = 1
batch_size = batch
padding = 3 if with_padding else 0
padded_batch_size = batch_size + padding
nsplits = padded_batch_size - 1
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
seqlens.append(
torch.diff(
torch.cat(
......@@ -373,19 +298,22 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0)
x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
dtype=itype)[:, 4096:4096 + dim, :]
x = rearrange(
torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype),
"b s d -> b d s")[:, 4096:4096 + dim, :]
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
activation = None if not silu_activation else "silu"
final_states = torch.randn(total_entries,
dim,
width - 1,
dim,
device=x.device,
dtype=x.dtype)
dtype=x.dtype).transpose(1, 2)
final_states_ref = final_states.clone()
has_initial_states = torch.randint(0,
2, (cumsum.shape[0] - 1, ),
......@@ -400,10 +328,16 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=-1)
out = causal_conv1d_fn(x.squeeze(0),
weight,
bias=bias,
conv_states=final_states,
query_start_loc=cumsum.cuda(),
cache_indices=padded_state_indices,
has_initial_state=has_initial_states,
activation=activation,
pad_slot_id=PAD_SLOT_ID)
out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
padded_state_indices, has_initial_states,
final_states, activation, PAD_SLOT_ID)
out_ref = []
out_ref_b = []
......@@ -426,13 +360,9 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
out_ref_tensor = torch.cat(out_ref, dim=0)
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
assert torch.allclose(final_states[state_indices],
final_states_ref[state_indices],
rtol=rtol,
atol=atol)
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
padded_state_indices, has_initial_states,
final_states, activation)
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
......@@ -6,11 +6,11 @@ import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.mamba2_metadata import (
_query_start_loc_to_chunk_indices_offsets)
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba_attn import (
_query_start_loc_to_chunk_indices_offsets)
# Added by the IBM Team, 2024
......
......@@ -1464,30 +1464,6 @@ def ggml_moe_get_block_size(quant_type: int) -> int:
# mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
conv_states: Optional[torch.Tensor],
query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
silu_activation: bool, pad_slot_id: int):
torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
query_start_loc, cache_indices,
has_initial_state, silu_activation,
pad_slot_id)
def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor, bias_: Optional[torch.Tensor],
silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor],
pad_slot_id: int):
torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
silu_activation, cache_seqlens,
conv_state_indices, pad_slot_id)
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, C: torch.Tensor,
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba_attn import (
Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
@dataclass
......@@ -21,6 +25,29 @@ class Mamba2Metadata:
seq_idx: torch.Tensor
chunk_indices: torch.Tensor
chunk_offsets: torch.Tensor
"""
With continuous batching layout of `x` in vLLM, to enable a Triton program
to handle a request in parallel, two supporting tensors are used
(batch_ptr, token_chunk_offset_ptr)
BLOCK_M = the # tokens to be handled by a Triton program
(can be customized for different hardware)
nums_dict:
tracks the data associated with a given value of BLOCK_M
BLOCK_M = #tokens handled by a Triton program
cu_seqlen: total tokens per batch
(used as flag to update other data at each new input)
batch_ptr: tracks batch-id handled by the Triton program
token_chunk_offset_ptr: tracks token group_idx handled by the Triton program
(Triton implementation of causal_conv1d handles parallelism in 3-axes
- feature-axis
- batch-axis
- sequence-axis)
"""
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.tensor] = None
token_chunk_offset_ptr: Optional[torch.tensor] = None
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
......@@ -38,45 +65,10 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
f"Unsupported platform for Mamba2: {current_platform.device_type}")
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
total_seqlens: int):
cu_seqlens = query_start_loc[1:] # remove prepended 0
# outputs will have length expansion of chunks that do not divide
# chunk_size
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
> 0).sum()
chunk_indices = torch.arange(N,
dtype=torch.int,
device=query_start_loc.device)
chunk_offsets = torch.zeros((N, ),
dtype=torch.int,
device=query_start_loc.device)
p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
# if does not divide chunk_size, then there is one chunk insertion
p += (s % chunk_size > 0)
# get the dimensions
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
> 0)
# adjust inidces and offsets
chunk_indices[_s:_e] -= p
chunk_offsets[_s] = s % chunk_size
return chunk_indices, chunk_offsets
def prepare_mamba2_metadata(
chunk_size: int,
attn_metadata: AttentionMetadata,
mamba2_metadata=None,
) -> Mamba2Metadata:
# compute number of prefill and decode requests
......@@ -96,12 +88,12 @@ def prepare_mamba2_metadata(
attn_metadata_instances = get_platform_metadata_classes()
if (isinstance(attn_metadata, attn_metadata_instances)
and attn_metadata.context_lens_tensor is not None):
has_initial_states = \
attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,]
# precompute flag to avoid device syncs in mamba2 layer forwards
# precompute flag to avoid device syncs later in mamba2 layer
# forwards
# prep is only needed for mamba2 ssd prefill processing
prep_initial_states = torch.any(has_initial_states).item()
has_initial_states = attn_metadata.context_lens_tensor > 0
prep_initial_states = torch.any(
has_initial_states[:num_prefills]).item()
query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
seq_idx = torch.repeat_interleave(torch.arange(
num_prefills, dtype=torch.int32, device=query_start_loc.device),
......@@ -117,9 +109,78 @@ def prepare_mamba2_metadata(
_query_start_loc_to_chunk_indices_offsets(
query_start_loc, chunk_size, num_prefill_tokens)
if mamba2_metadata is not None:
mamba2_metadata.has_initial_states = has_initial_states
mamba2_metadata.prep_initial_states = prep_initial_states
mamba2_metadata.chunk_size = chunk_size
mamba2_metadata.seq_idx = seq_idx
mamba2_metadata.chunk_indices = chunk_indices
mamba2_metadata.chunk_offsets = chunk_offsets
# We use 1 reset flag:
# * mamba2_metadata.cu_seqlen is None
# update config specific to (each input)
# (become available at first layer, e.g. conv_weights)
mamba2_metadata.cu_seqlen = None # suppose to be updated at each input
return mamba2_metadata
return Mamba2Metadata(has_initial_states=has_initial_states,
prep_initial_states=prep_initial_states,
chunk_size=chunk_size,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets)
def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
mamba2_metadata: Union[Mamba2Metadata,
Mamba2AttentionMetadata]):
"""
this is triggered upon handling a new input at the first layer
"""
dim, cu_seqlen = x.shape
mamba2_metadata.cu_seqlen = cu_seqlen
seqlens = np.diff(query_start_loc.to('cpu'))
nums_dict = {} # type: ignore
for BLOCK_M in [8]: # cover all BLOCK_M values
nums = -(-seqlens // BLOCK_M)
nums_dict[BLOCK_M] = {}
nums_dict[BLOCK_M]['nums'] = nums
nums_dict[BLOCK_M]['tot'] = nums.sum().item()
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
nums_dict[BLOCK_M]['mlist'] = mlist
mlist_len = len(nums_dict[BLOCK_M]['mlist'])
nums_dict[BLOCK_M]['mlist_len'] = mlist_len
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
offsetlist = [] # type: ignore
for idx, num in enumerate(nums):
offsetlist.extend(range(num))
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
nums_dict[BLOCK_M]['offsetlist'] = offsetlist
if mamba2_metadata.batch_ptr is None:
# Update default value after class definition
#mamba2_metadata.MAX_NUM_PROGRAMS *= 2
mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
mamba2_metadata.token_chunk_offset_ptr = torch.full(
(MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
else:
if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS:
mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(
PAD_SLOT_ID)
mamba2_metadata.token_chunk_offset_ptr.resize_( # type: ignore
MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist)
mamba2_metadata.token_chunk_offset_ptr[ # type: ignore
0:mlist_len].copy_(offsetlist)
nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr
nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (
mamba2_metadata.token_chunk_offset_ptr) # type: ignore
mamba2_metadata.nums_dict = nums_dict
return mamba2_metadata
......@@ -159,7 +159,7 @@ class MambaMixer(CustomOp):
hidden_states = causal_conv1d_fn(
hidden_states,
conv_weights,
self.conv1d.bias,
bias=self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
......
......@@ -17,7 +17,8 @@ from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
update_metadata)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
......@@ -458,9 +459,11 @@ class MambaMixer2(CustomOp):
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
mamba2_metadata = attn_metadata
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states
......@@ -531,6 +534,7 @@ class MambaMixer2(CustomOp):
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
if envs.VLLM_USE_V1:
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
hidden_states_B_C,
......@@ -579,8 +583,13 @@ class MambaMixer2(CustomOp):
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "state_indices_tensor"
x = hidden_states_B_C_p.transpose(
0, 1) # this is the form that causal-conv see
if mamba2_metadata.cu_seqlen is None:
mamba2_metadata = update_metadata(
x, attn_metadata.query_start_loc, mamba2_metadata)
hidden_states_B_C_p = causal_conv1d_fn(
hidden_states_B_C_p.transpose(0, 1),
x,
conv_weights,
self.conv1d.bias,
activation=self.activation,
......@@ -590,8 +599,6 @@ class MambaMixer2(CustomOp):
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
# TODO: Why is this needed?
hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
hidden_states_B_C_p)
......@@ -715,9 +722,10 @@ class MambaMixer2(CustomOp):
# - heads and n_groups are TP-ed
conv_dim = (self.intermediate_size +
2 * n_groups * self.ssm_state_size)
# contiguous along 'dim' axis
conv_state_shape = (
divide(conv_dim, world_size),
self.conv_kernel_size - 1,
divide(conv_dim, world_size),
)
# These are not TP-ed as they depend on A, dt_bias, D
......
......@@ -36,10 +36,12 @@ class MambaCacheManager(ConstantSizeCache):
# Initialize parent class
super().__init__(max_batch_size)
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
conv_state_shape,
(conv_state_shape[1], conv_state_shape[0]),
dtype=dtype,
device="cuda")
device="cuda").transpose(-1, -2)
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=dtype,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.mamba.mamba2_metadata import (
_query_start_loc_to_chunk_indices_offsets)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import MambaSpec
......@@ -29,6 +28,42 @@ def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
return chunk_sizes.pop()
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
total_seqlens: int):
cu_seqlens = query_start_loc[1:] # remove prepended 0
# outputs will have length expansion of chunks that do not divide
# chunk_size
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
> 0).sum()
chunk_indices = torch.arange(N,
dtype=torch.int,
device=query_start_loc.device)
chunk_offsets = torch.zeros((N, ),
dtype=torch.int,
device=query_start_loc.device)
p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
# if does not divide chunk_size, then there is one chunk insertion
p += (s % chunk_size > 0)
# get the dimensions
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
> 0)
# adjust indices and offsets
chunk_indices[_s:_e] -= p
chunk_offsets[_s] = s % chunk_size
return chunk_indices, chunk_offsets
class Mamba2AttentionBackend(AttentionBackend):
@staticmethod
......@@ -53,6 +88,10 @@ class Mamba2AttentionMetadata:
chunk_offsets: torch.Tensor
state_indices_tensor: torch.Tensor # shape: [batch,]
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.tensor] = None
token_chunk_offset_ptr: Optional[torch.tensor] = None
class Mamba2AttentionMetadataBuilder(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment