Unverified Commit fdd9daaf authored by Mor Zusman's avatar Mor Zusman Committed by GitHub
Browse files

[Kernel/Model] Migrate mamba_ssm and causal_conv1d kernels to vLLM (#7651)

parent 8c56e57d
...@@ -203,6 +203,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -203,6 +203,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_MakeAvailable(cutlass) FetchContent_MakeAvailable(cutlass)
list(APPEND VLLM_EXT_SRC list(APPEND VLLM_EXT_SRC
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
"csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
......
...@@ -42,9 +42,6 @@ COPY requirements-cuda.txt requirements-cuda.txt ...@@ -42,9 +42,6 @@ COPY requirements-cuda.txt requirements-cuda.txt
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-cuda.txt python3 -m pip install -r requirements-cuda.txt
COPY requirements-mamba.txt requirements-mamba.txt
RUN python3 -m pip install packaging
RUN python3 -m pip install -r requirements-mamba.txt
# cuda arch list used by torch # cuda arch list used by torch
# can be useful for both `dev` and `test` # can be useful for both `dev` and `test`
...@@ -127,22 +124,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \ ...@@ -127,22 +124,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-dev.txt python3 -m pip install -r requirements-dev.txt
#################### DEV IMAGE #################### #################### DEV IMAGE ####################
#################### MAMBA Build IMAGE ####################
FROM dev as mamba-builder
# max jobs used for build
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
WORKDIR /usr/src/mamba
COPY requirements-mamba.txt requirements-mamba.txt
# Download the wheel or build it if a pre-compiled release doesn't exist
RUN pip --verbose wheel -r requirements-mamba.txt \
--no-build-isolation --no-deps --no-cache-dir
#################### MAMBA Build IMAGE ####################
#################### vLLM installation IMAGE #################### #################### vLLM installation IMAGE ####################
# image with vLLM installed # image with vLLM installed
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base
...@@ -179,10 +160,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist ...@@ -179,10 +160,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
--mount=type=cache,target=/root/.cache/pip \ --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install dist/*.whl --verbose python3 -m pip install dist/*.whl --verbose
RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
. /etc/environment && \ . /etc/environment && \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
......
This diff is collapsed.
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
struct ConvParamsBase {
using index_t = uint32_t;
int batch, dim, seqlen, width;
bool silu_activation;
index_t x_batch_stride;
index_t x_c_stride;
index_t x_l_stride;
index_t weight_c_stride;
index_t weight_width_stride;
index_t out_batch_stride;
index_t out_c_stride;
index_t out_l_stride;
index_t conv_state_batch_stride;
index_t conv_state_c_stride;
index_t conv_state_l_stride;
// Common data pointers.
void *__restrict__ x_ptr;
void *__restrict__ weight_ptr;
void *__restrict__ bias_ptr;
void *__restrict__ out_ptr;
void *__restrict__ conv_state_ptr;
void *__restrict__ seq_idx_ptr;
// No __restrict__ since initial_states could be the same as final_states.
void * initial_states_ptr;
index_t initial_states_batch_stride;
index_t initial_states_l_stride;
index_t initial_states_c_stride;
void * final_states_ptr;
index_t final_states_batch_stride;
index_t final_states_l_stride;
index_t final_states_c_stride;
};
#ifndef USE_ROCM
#include <cuda_bf16.h>
template<typename T>
__device__ inline T shuffle_xor(T val, int offset) {
return __shfl_xor_sync(uint32_t(-1), val, offset);
}
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
return std::max(ilist);
}
template<typename T>
constexpr T constexpr_min(T a, T b) {
return std::min(a, b);
}
#else
#include <hip/hip_bf16.h>
template<typename T>
__device__ inline T shuffle_xor(T val, int offset) {
return __shfl_xor(val, offset);
}
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
return *std::max_element(ilist.begin(), ilist.end());
}
template<typename T>
constexpr T constexpr_min(T a, T b) {
return a < b ? a : b;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int BYTES> struct BytesToType {};
template<> struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<> struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<> struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<> struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<> struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
};
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
static constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
static constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// clang-format off
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h
#pragma once
#ifndef USE_ROCM
#include <cuda_bf16.h>
#else
#include <hip/hip_bf16.h>
#endif
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SSMParamsBase {
using index_t = uint32_t;
int batch, dim, seqlen, dstate, n_groups, n_chunks;
int dim_ngroups_ratio;
bool is_variable_B;
bool is_variable_C;
bool delta_softplus;
index_t A_d_stride;
index_t A_dstate_stride;
index_t B_batch_stride;
index_t B_d_stride;
index_t B_dstate_stride;
index_t B_group_stride;
index_t C_batch_stride;
index_t C_d_stride;
index_t C_dstate_stride;
index_t C_group_stride;
index_t u_batch_stride;
index_t u_d_stride;
index_t delta_batch_stride;
index_t delta_d_stride;
index_t z_batch_stride;
index_t z_d_stride;
index_t out_batch_stride;
index_t out_d_stride;
index_t out_z_batch_stride;
index_t out_z_d_stride;
// Common data pointers.
void *__restrict__ A_ptr;
void *__restrict__ B_ptr;
void *__restrict__ C_ptr;
void *__restrict__ D_ptr;
void *__restrict__ u_ptr;
void *__restrict__ delta_ptr;
void *__restrict__ delta_bias_ptr;
void *__restrict__ out_ptr;
void *__restrict__ x_ptr;
void *__restrict__ z_ptr;
void *__restrict__ out_z_ptr;
void *__restrict__ index_ptr;
};
#ifndef USE_ROCM
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
return std::max(ilist);
}
template<typename T>
constexpr T constexpr_min(T a, T b) {
return std::min(a, b);
}
#else
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
return *std::max_element(ilist.begin(), ilist.end());
}
template<typename T>
constexpr T constexpr_min(T a, T b) {
return a < b ? a : b;
}
#endif
#define MAX_DSTATE 256
inline __device__ float2 operator+(const float2 & a, const float2 & b){
return {a.x + b.x, a.y + b.y};
}
inline __device__ float3 operator+(const float3 &a, const float3 &b) {
return {a.x + b.x, a.y + b.y, a.z + b.z};
}
inline __device__ float4 operator+(const float4 & a, const float4 & b){
return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int BYTES> struct BytesToType {};
template<> struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<> struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<> struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<> struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<> struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename scalar_t, int N>
struct Converter{
static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
}
};
template<int N>
struct Converter<at::Half, N>{
static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
static_assert(N % 2 == 0);
auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
#pragma unroll
for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
}
};
#if __CUDA_ARCH__ >= 800
template<int N>
struct Converter<at::BFloat16, N>{
static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
static_assert(N % 2 == 0);
auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
#pragma unroll
for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
}
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename scalar_t> struct SSMScanOp;
template<>
struct SSMScanOp<float> {
__device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
}
};
// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template <typename scalar_t> struct SSMScanPrefixCallbackOp {
using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
scan_t running_prefix;
// Constructor
__device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
__device__ scan_t operator()(scan_t block_aggregate) {
scan_t old_prefix = running_prefix;
running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
return old_prefix;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Ktraits>
inline __device__ void load_input(typename Ktraits::input_t *u,
typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadT::TempStorage &smem_load,
int seqlen) {
if constexpr (Ktraits::kIsEvenLen) {
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
reinterpret_cast<vec_t*>(u),
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
#ifdef USE_ROCM
, Ktraits::kNThreads * Ktraits::kNLoads
#endif
);
} else {
typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
}
}
template<typename Ktraits>
inline __device__ void load_index(int *u,
int (&u_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index,
int seqlen) {
if constexpr (Ktraits::kIsEvenLen) {
auto& smem_load_index_vec = reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(smem_load_index);
Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(
reinterpret_cast<uint4*>(u),
reinterpret_cast<uint4(&)[Ktraits::kNLoadsIndex]>(u_vals)
);
} else {
Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0);
}
}
template<typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
int seqlen) {
constexpr int kNItems = Ktraits::kNItems;
typename Ktraits::input_t B_vals_load[kNItems];
if constexpr (Ktraits::kIsEvenLen) {
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
reinterpret_cast<vec_t*>(Bvar),
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
);
} else {
typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
}
// #pragma unroll
// for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
}
template<typename Ktraits>
inline __device__ void store_output(typename Ktraits::input_t *out,
const float (&out_vals)[Ktraits::kNItems],
typename Ktraits::BlockStoreT::TempStorage &smem_store,
int seqlen) {
typename Ktraits::input_t write_vals[Ktraits::kNItems];
#pragma unroll
for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
if constexpr (Ktraits::kIsEvenLen) {
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
reinterpret_cast<vec_t*>(out),
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
);
} else {
typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
}
}
This diff is collapsed.
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// clang-format off
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/static_switch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
...@@ -195,6 +195,28 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -195,6 +195,28 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad); torch::Tensor num_tokens_post_pad);
std::vector<torch::Tensor> selective_scan_fwd(
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
const torch::Tensor& B, const torch::Tensor& C,
const c10::optional<torch::Tensor>& D_,
const c10::optional<torch::Tensor>& z_,
const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
const c10::optional<torch::Tensor>& index_,
const c10::optional<torch::Tensor>& x);
at::Tensor causal_conv1d_update(const at::Tensor& x,
const at::Tensor& conv_state,
const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
bool silu_activation);
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
const c10::optional<at::Tensor>& seq_idx_,
const c10::optional<at::Tensor>& initial_states_,
const c10::optional<at::Tensor>& final_states_out_,
bool silu_activation);
#ifndef USE_ROCM #ifndef USE_ROCM
using fptr_t = int64_t; using fptr_t = int64_t;
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
......
...@@ -202,6 +202,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -202,6 +202,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8); ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA, ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
&cutlass_scaled_mm_supports_fp8); &cutlass_scaled_mm_supports_fp8);
// Mamba selective scan kernel
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? index_, Tensor? x) -> Tensor[]");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
ops.def(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias_,"
"bool silu_activation) -> Tensor");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
ops.def(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor? seq_idx_,"
"Tensor? initial_states_,"
"Tensor? final_states_out_,"
"bool silu_activation) -> Tensor");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#endif #endif
// Quantized GEMM for GPTQ. // Quantized GEMM for GPTQ.
......
# Mamba dependencies
mamba-ssm>=1.2.2
causal-conv1d>=1.2.0
...@@ -11,7 +11,7 @@ pytest-shard ...@@ -11,7 +11,7 @@ pytest-shard
# testing utils # testing utils
awscli awscli
einops # required for MPT and qwen-vl einops # required for MPT, qwen-vl and Mamba
httpx httpx
peft peft
requests requests
......
from typing import Optional
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
def causal_conv1d_ref(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None,
return_final_states: bool = False,
final_states_out: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1)
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
x = x.to(weight.dtype)
seqlen = x.shape[-1]
dim, width = weight.shape
if initial_states is None:
out = F.conv1d(x,
weight.unsqueeze(1),
bias,
padding=width - 1,
groups=dim)
else:
x = torch.cat([initial_states, x], dim=-1)
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
out = out[..., :seqlen]
if return_final_states:
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
dtype_in) # (batch, dim, width - 1)
if final_states_out is not None:
final_states_out.copy_(final_states)
else:
final_states_out = final_states
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
return (out, None) if not return_final_states else (out, final_states_out)
def causal_conv1d_update_ref(x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = None):
"""
x: (batch, dim)
conv_state: (batch, dim, width)
weight: (dim, width)
bias: (dim,)
out: (batch, dim)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
batch, dim = x.shape
width = weight.shape[1]
assert conv_state.shape == (batch, dim, width)
assert weight.shape == (dim, width)
conv_state.copy_(torch.roll(conv_state, shifts=-1,
dims=-1)) # Update state (B D W)
conv_state[:, :, -1] = x
out = torch.sum(conv_state * weight, dim=-1) # (B D)
if bias is not None:
out += bias
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
@pytest.mark.parametrize("return_final_states", [False, True])
@pytest.mark.parametrize("has_initial_states", [False, True])
@pytest.mark.parametrize("channel_last", [False, True])
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("seqlen", [128, 512, 4096])
@pytest.mark.parametrize('dim', [64, 4096 + 32])
@pytest.mark.parametrize('batch', [1, 2])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
itype, channel_last, has_initial_states,
return_final_states):
if not channel_last and (has_initial_states or return_final_states):
pytest.skip(
"Only channel_last support initial_states or return_final_states")
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
torch.random.manual_seed(0)
if not channel_last:
x = torch.randn(batch,
4096 + dim + 64,
seqlen,
device=device,
dtype=itype)[:, 4096:4096 + dim, :]
else:
x = rearrange(
torch.randn(batch,
seqlen,
4096 + dim + 64,
device=device,
dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s")
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
if has_initial_states:
initial_states = torch.randn(batch,
width - 1,
dim,
device=device,
dtype=itype).transpose(1, 2)
else:
initial_states = None
x_ref = x.detach().clone()
weight_ref = weight.detach().clone()
bias_ref = bias.detach().clone() if bias is not None else None
initial_states_ref = initial_states.detach().clone(
) if initial_states is not None else None
activation = None if not silu_activation else "silu"
out, final_states = causal_conv1d_fn(
x,
weight,
bias,
initial_states=initial_states,
return_final_states=return_final_states,
activation=activation)
out_ref, final_states_ref = causal_conv1d_ref(
x_ref,
weight_ref,
bias_ref,
initial_states=initial_states_ref,
return_final_states=return_final_states,
activation=activation)
if return_final_states:
assert final_states is not None and final_states_ref is not None
assert torch.allclose(final_states,
final_states_ref,
rtol=rtol,
atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
if return_final_states:
out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)
out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
@pytest.mark.parametrize("batch", [1, 2])
def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
torch.random.manual_seed(0)
batch = 2
x = torch.randn(batch, dim, device=device, dtype=itype)
conv_state = torch.randn(batch, dim, width, device=device, dtype=itype)
weight = torch.randn(dim,
width,
device=device,
dtype=itype,
requires_grad=True)
if has_bias:
bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
else:
bias = None
conv_state_ref = conv_state.detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x,
conv_state,
weight,
bias,
activation=activation)
out_ref = causal_conv1d_update_ref(x,
conv_state_ref,
weight,
bias,
activation=activation)
assert torch.equal(conv_state, conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
def selective_state_update_ref(state,
x,
dt,
A,
B,
C,
D=None,
z=None,
dt_bias=None,
dt_softplus=False):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
x: (batch, dim) or (batch, nheads, dim)
dt: (batch, dim) or (batch, nheads, dim)
A: (dim, dstate) or (nheads, dim, dstate)
B: (batch, dstate) or (batch, ngroups, dstate)
C: (batch, dstate) or (batch, ngroups, dstate)
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
Return:
out: (batch, dim) or (batch, nheads, dim)
"""
has_heads = state.dim() > 3
if state.dim() == 3:
state = state.unsqueeze(1)
if x.dim() == 2:
x = x.unsqueeze(1)
if dt.dim() == 2:
dt = dt.unsqueeze(1)
if A.dim() == 2:
A = A.unsqueeze(0)
if B.dim() == 2:
B = B.unsqueeze(1)
if C.dim() == 2:
C = C.unsqueeze(1)
if D is not None and D.dim() == 1:
D = D.unsqueeze(0)
if z is not None and z.dim() == 2:
z = z.unsqueeze(1)
if dt_bias is not None and dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0)
batch, nheads, dim, dstate = state.shape
assert x.shape == (batch, nheads, dim)
assert dt.shape == x.shape
assert A.shape == (nheads, dim, dstate)
ngroups = B.shape[1]
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
assert B.shape == (batch, ngroups, dstate)
assert C.shape == B.shape
if D is not None:
assert D.shape == (nheads, dim)
if z is not None:
assert z.shape == x.shape
if dt_bias is not None:
assert dt_bias.shape == (nheads, dim)
dt = dt + dt_bias
dt = F.softplus(dt) if dt_softplus else dt
dA = torch.exp(rearrange(dt, "b h d -> b h d 1") *
A) # (batch, nheads, dim, dstate)
B = repeat(B, "b g n -> b (g h) n",
h=nheads // ngroups) # (batch, nheads, dstate)
C = repeat(C, "b g n -> b (g h) n",
h=nheads // ngroups) # (batch, nheads, dstate)
dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate)
state.copy_(state * dA +
dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate
out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
if D is not None:
out += (x * D).to(out.dtype)
out = (out if z is None else out * F.silu(z)).to(x.dtype)
if not has_heads:
out = out.squeeze(1)
return out
def selective_scan_ref(u,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
return_last_state=False,
position_indices=None,
prev_state=None):
"""
u: r(B D L)
delta: r(B D L)
A: c(D N) or r(D N)
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
D: r(D)
z: r(B D L)
delta_bias: r(D), fp32
prev_state: r(B D N), fp32
out: r(B D L)
last_state (optional): r(B D dstate) or c(B D dstate)
"""
dtype_in = u.dtype
u = u.float()
delta = delta.float()
if delta_bias is not None:
delta = delta + delta_bias[..., None].float()
if delta_softplus:
delta = F.softplus(delta)
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
is_variable_B = B.dim() >= 3
is_variable_C = C.dim() >= 3
B = B.float()
C = C.float()
x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state
ys = []
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
if not is_variable_B:
deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
else:
if B.dim() == 3:
deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
else:
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
if is_variable_C and C.dim() == 4:
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
last_state = None
for i in range(u.shape[2]):
if position_indices is not None and position_indices[0, i] == 0:
x = deltaB_u[:, :, i]
else:
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
if not is_variable_C:
y = torch.einsum('bdn,dn->bd', x, C)
else:
if C.dim() == 3:
y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
else:
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
if i == u.shape[2] - 1:
last_state = x
ys.append(y)
y = torch.stack(ys, dim=2) # (batch dim L)
out = y if D is None else y + u * rearrange(D, "d -> d 1")
if z is not None:
out = out * F.silu(z)
out = out.to(dtype=dtype_in)
return out if not return_last_state else (out, last_state)
@pytest.mark.parametrize('wtype', [torch.float32])
@pytest.mark.parametrize('itype', [torch.float32])
@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
@pytest.mark.parametrize("return_last_state", [True])
@pytest.mark.parametrize('has_delta_bias', [True])
@pytest.mark.parametrize('delta_softplus', [True])
@pytest.mark.parametrize('has_z', [True])
@pytest.mark.parametrize('has_D', [True])
@pytest.mark.parametrize("varBC_groups", [1, 2])
@pytest.mark.parametrize("is_variable_C", [True])
@pytest.mark.parametrize("is_variable_B", [True])
@pytest.mark.parametrize("scan_chunks", [1, 2, 3])
def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
has_z, has_delta_bias, delta_softplus,
return_last_state, seqlen, itype, wtype, scan_chunks):
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
pytest.skip() # This config is not applicable
device = 'cuda'
rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 3e-2, 5e-2
rtolw, atolw = (1e-3, 1e-3)
if has_z: # If we have z, the errors on the weights seem higher
rtolw = max(rtolw, rtol)
atolw = max(atolw, atol)
# set seed
torch.random.manual_seed(0)
batch_size = 2
dim = 4
dstate = 8
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
if not is_variable_B:
B_shape = [dim, dstate]
elif varBC_groups == 1:
B_shape = [batch_size, dstate, seqlen]
else:
B_shape = [batch_size, varBC_groups, dstate, seqlen]
B = torch.randn(B_shape,
device=device,
dtype=wtype if not is_variable_B else itype)
if not is_variable_C:
C_shape = [dim, dstate]
elif varBC_groups == 1:
C_shape = [batch_size, dstate, seqlen]
else:
C_shape = [batch_size, varBC_groups, dstate, seqlen]
C = torch.randn(C_shape,
device=device,
dtype=wtype if not is_variable_C else itype)
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
z = torch.randn(batch_size, dim, seqlen, device=device,
dtype=itype) if has_z else None
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
) if has_delta_bias else None
u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
delta = (0.5 *
torch.rand(batch_size, dim, seqlen, device=device, dtype=itype))
state = None
state_ref = None
out = None
out_ref = None
outs = []
for c in range(scan_chunks):
chunked_prompt_len = seqlen // scan_chunks
chunk_start = chunked_prompt_len * c
chunk_end = chunked_prompt_len * (c + 1)
if c == scan_chunks - 1:
chunk_end = seqlen
_B = B
if is_variable_B:
_B = B[..., chunk_start:chunk_end]
_C = C
if is_variable_B:
_C = C[..., chunk_start:chunk_end]
_z = z
if has_z:
assert z is not None
_z = z[..., chunk_start:chunk_end]
out, *rest = selective_scan_fn(u[..., chunk_start:chunk_end],
delta[..., chunk_start:chunk_end],
A,
_B,
_C,
D,
z=_z,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
return_last_state=return_last_state,
prev_state=state if c > 0 else None)
outs.append(out)
if return_last_state:
state = rest[0]
if len(outs) > 1:
out = torch.cat(outs, dim=-1)
out_ref, *rest = selective_scan_ref(u,
delta,
A,
B,
C,
D,
z=z,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
return_last_state=return_last_state)
if return_last_state:
state_ref = rest[0]
assert out is not None and out_ref is not None
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
if return_last_state:
assert state is not None and state_ref is not None
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_selective_state_update(dim, dstate, has_z, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
if torch.version.hip:
atol *= 2
# set seed
torch.random.manual_seed(0)
batch_size = 1
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=itype)
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(batch_size, dstate, device=device)
C = torch.randn(batch_size, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state.detach().clone()
out = selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True)
out_ref = selective_state_update_ref(state_ref,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True)
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
...@@ -500,6 +500,36 @@ def ggml_mul_mat_a8( ...@@ -500,6 +500,36 @@ def ggml_mul_mat_a8(
return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row) return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)
# mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
seq_idx_: Optional[torch.Tensor],
initial_states_: Optional[torch.Tensor],
final_states_out_: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_,
initial_states_, final_states_out_,
silu_activation)
def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor, bias_: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
silu_activation)
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, C: torch.Tensor,
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
delta_bias_: Optional[torch.Tensor],
delta_softplus: bool, index_: Optional[torch.Tensor],
x: Optional[torch.Tensor]) -> List[torch.Tensor]:
return torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_,
delta_bias_, delta_softplus, index_,
x)
# moe # moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor, block_size: int, sorted_token_ids: torch.Tensor,
......
# Copyright (c) 2024, Tri Dao.
from typing import Optional
import torch
from vllm import _custom_ops as ops
def causal_conv1d_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None,
return_final_states: bool = False,
final_states_out=None,
activation: str = "silu",
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
seq_idx: (batch, seqlen)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1), to be written to
activation: either None or "silu" or "swish"
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(2) != 1 and x.stride(1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
if seq_idx is not None:
assert (initial_states is
None), "initial_states must be None if seq_idx is not None"
assert (not return_final_states
), "If seq_idx is not None, we don't return final_states_out"
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
if initial_states is not None and (initial_states.stride(2) != 1
and initial_states.stride(1) != 1):
initial_states = initial_states.contiguous()
if return_final_states:
assert (
x.stride(1) == 1
), "Only channel-last layout support returning final_states_out"
if final_states_out is not None:
assert (final_states_out.stride(2) == 1
or final_states_out.stride(1) == 1)
else:
batch, dim, seqlen = x.shape
width = weight.shape[1]
final_states_out = torch.empty(batch,
width - 1,
dim,
device=x.device,
dtype=x.dtype).transpose(1, 2)
else:
final_states_out = None
out = ops.causal_conv1d_fwd(x, weight, bias, seq_idx, initial_states,
final_states_out, activation
in ["silu", "swish"])
return (out, None) if not return_final_states else (out, final_states_out)
def causal_conv1d_update(x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = None):
"""
x: (batch, dim)
conv_state: (batch, dim, width)
weight: (dim, width)
bias: (dim,)
out: (batch, dim)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
activation_bool = activation in ["silu", "swish"]
return ops.causal_conv1d_update(x, conv_state, weight, bias,
activation_bool)
# Copyright (c) 2024, Tri Dao, Albert Gu.
import torch
import triton
import triton.language as tl
from packaging import version
from vllm import _custom_ops as ops
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
if TRITON3:
@triton.jit
def softplus(dt):
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
return dt
else:
@triton.jit
def softplus(dt):
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
return dt
@triton.heuristics(
{"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics(
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
def _selective_scan_update_kernel(
# Pointers to matrices
state_ptr,
x_ptr,
dt_ptr,
dt_bias_ptr,
A_ptr,
B_ptr,
C_ptr,
D_ptr,
z_ptr,
out_ptr,
# Matrix dimensions
batch,
nheads,
dim,
dstate,
nheads_ngroups_ratio,
# Strides
stride_state_batch,
stride_state_head,
stride_state_dim,
stride_state_dstate,
stride_x_batch,
stride_x_head,
stride_x_dim,
stride_dt_batch,
stride_dt_head,
stride_dt_dim,
stride_dt_bias_head,
stride_dt_bias_dim,
stride_A_head,
stride_A_dim,
stride_A_dstate,
stride_B_batch,
stride_B_group,
stride_B_dstate,
stride_C_batch,
stride_C_group,
stride_C_dstate,
stride_D_head,
stride_D_dim,
stride_z_batch,
stride_z_head,
stride_z_dim,
stride_out_batch,
stride_out_head,
stride_out_dim,
# Meta-parameters
DT_SOFTPLUS: tl.constexpr,
TIE_HDIM: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
HAS_DT_BIAS: tl.constexpr,
HAS_D: tl.constexpr,
HAS_Z: tl.constexpr,
BLOCK_SIZE_DSTATE: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
if HAS_DT_BIAS:
dt_bias_ptr += pid_h * stride_dt_bias_head
A_ptr += pid_h * stride_A_head
B_ptr += pid_b * stride_B_batch + (pid_h //
nheads_ngroups_ratio) * stride_B_group
C_ptr += pid_b * stride_C_batch + (pid_h //
nheads_ngroups_ratio) * stride_C_group
if HAS_Z:
z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +
offs_n[None, :] * stride_state_dstate)
x_ptrs = x_ptr + offs_m * stride_x_dim
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
if HAS_DT_BIAS:
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
if HAS_D:
D_ptr += pid_h * stride_D_head
A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim +
offs_n[None, :] * stride_A_dstate)
B_ptrs = B_ptr + offs_n * stride_B_dstate
C_ptrs = C_ptr + offs_n * stride_C_dstate
if HAS_D:
D_ptrs = D_ptr + offs_m * stride_D_dim
if HAS_Z:
z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
state = tl.load(state_ptrs,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
other=0.0)
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if not TIE_HDIM:
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(A_ptrs,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
other=0.0).to(tl.float32)
dA = tl.exp(A * dt[:, None])
else:
dt = tl.load(dt_ptr).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptr).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(A_ptr).to(tl.float32)
dA = tl.exp(A * dt) # scalar, not a matrix
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
if HAS_D:
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_Z:
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
state = state * dA + dB * x[:, None]
tl.store(state_ptrs,
state,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
out += x * D
if HAS_Z:
out *= z * tl.sigmoid(z)
tl.store(out_ptrs, out, mask=offs_m < dim)
def selective_state_update(state,
x,
dt,
A,
B,
C,
D=None,
z=None,
dt_bias=None,
dt_softplus=False):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
x: (batch, dim) or (batch, nheads, dim)
dt: (batch, dim) or (batch, nheads, dim)
A: (dim, dstate) or (nheads, dim, dstate)
B: (batch, dstate) or (batch, ngroups, dstate)
C: (batch, dstate) or (batch, ngroups, dstate)
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
Return:
out: (batch, dim) or (batch, nheads, dim)
"""
has_heads = state.dim() > 3
if state.dim() == 3:
state = state.unsqueeze(1)
if x.dim() == 2:
x = x.unsqueeze(1)
if dt.dim() == 2:
dt = dt.unsqueeze(1)
if A.dim() == 2:
A = A.unsqueeze(0)
if B.dim() == 2:
B = B.unsqueeze(1)
if C.dim() == 2:
C = C.unsqueeze(1)
if D is not None and D.dim() == 1:
D = D.unsqueeze(0)
if z is not None and z.dim() == 2:
z = z.unsqueeze(1)
if dt_bias is not None and dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0)
batch, nheads, dim, dstate = state.shape
assert x.shape == (batch, nheads, dim)
assert dt.shape == x.shape
assert A.shape == (nheads, dim, dstate)
ngroups = B.shape[1]
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
assert B.shape == (batch, ngroups, dstate)
assert C.shape == B.shape
if D is not None:
assert D.shape == (nheads, dim)
if z is not None:
assert z.shape == x.shape
if dt_bias is not None:
assert dt_bias.shape == (nheads, dim)
out = torch.empty_like(x)
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
(0, 0, 0))
# We don't want autotune since it will overwrite the state
# We instead tune by hand.
BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else
((16, 4) if dstate <= 32 else
((8, 4) if dstate <= 64 else
((4, 4) if dstate <= 128 else ((4, 8))))))
tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(
-1) == 0 and dt_bias.stride(-1) == 0
with torch.cuda.device(x.device.index):
_selective_scan_update_kernel[grid](
state,
x,
dt,
dt_bias,
A,
B,
C,
D,
z,
out,
batch,
nheads,
dim,
dstate,
nheads // ngroups,
state.stride(0),
state.stride(1),
state.stride(2),
state.stride(3),
x.stride(0),
x.stride(1),
x.stride(2),
dt.stride(0),
dt.stride(1),
dt.stride(2),
*(dt_bias.stride(0),
dt_bias.stride(1)) if dt_bias is not None else 0,
A.stride(0),
A.stride(1),
A.stride(2),
B.stride(0),
B.stride(1),
B.stride(2),
C.stride(0),
C.stride(1),
C.stride(2),
*(D.stride(0), D.stride(1)) if D is not None else 0,
z_strides[0],
z_strides[1],
z_strides[2],
out.stride(0),
out.stride(1),
out.stride(2),
dt_softplus,
tie_hdim,
BLOCK_SIZE_M,
num_warps=num_warps,
)
if not has_heads:
out = out.squeeze(1)
return out
def selective_scan_fn(u,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
return_last_state=False,
position_indices=None,
prev_state=None):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
"""
if u.stride(-1) != 1:
u = u.contiguous()
if delta.stride(-1) != 1:
delta = delta.contiguous()
if D is not None:
D = D.contiguous()
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if z is not None and z.stride(-1) != 1:
z = z.contiguous()
if B.dim() == 3:
B = B.unsqueeze(1)
if C.dim() == 3:
C = C.unsqueeze(1)
n_chunks = int((u.shape[-1] + 2048 - 1) / 2048)
x = torch.zeros((
u.shape[0],
u.shape[1],
n_chunks,
int(A.shape[1] * 2),
),
device=u.device,
dtype=torch.float32,
requires_grad=False)
x[:, :, 0, 0::2] = 1
if prev_state is not None:
x[:, :, 0, 1::2].copy_(prev_state)
out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
delta_softplus, position_indices, x)
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
if z is None:
return out if not return_last_state else (out, last_state)
else:
out_z = rest[0]
return out_z if not return_last_state else (out_z, last_state)
...@@ -4,9 +4,6 @@ from dataclasses import dataclass ...@@ -4,9 +4,6 @@ from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple from typing import Dict, Iterable, List, Optional, Tuple
import torch import torch
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from torch import nn from torch import nn
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from transformers import JambaConfig from transformers import JambaConfig
...@@ -24,6 +21,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -24,6 +21,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
...@@ -161,7 +162,7 @@ class JambaMambaMixer(nn.Module): ...@@ -161,7 +162,7 @@ class JambaMambaMixer(nn.Module):
(self.conv_kernel_size - hidden_states.shape[-1], 0)) (self.conv_kernel_size - hidden_states.shape[-1], 0))
cache_params.conv_state.copy_(conv_states) cache_params.conv_state.copy_(conv_states)
hidden_states = causal_conv1d_fn( hidden_states, _ = causal_conv1d_fn(
hidden_states, hidden_states,
conv_weights, conv_weights,
self.conv1d.bias, self.conv1d.bias,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment