Unverified Commit f13a07b1 authored by Mor Zusman's avatar Mor Zusman Committed by GitHub
Browse files

[Kernel][Model] Varlen prefill + Prefill chunking support for mamba kernels and Jamba model (#8533)

parent 6c9ba48f
This diff is collapsed.
......@@ -24,6 +24,7 @@ struct ConvParamsBase {
index_t out_c_stride;
index_t out_l_stride;
int conv_state_len;
index_t conv_state_batch_stride;
index_t conv_state_c_stride;
index_t conv_state_l_stride;
......@@ -35,6 +36,10 @@ struct ConvParamsBase {
void *__restrict__ out_ptr;
void *__restrict__ conv_state_ptr;
void *__restrict__ query_start_loc_ptr;
void *__restrict__ has_initial_state_ptr;
void *__restrict__ cache_indices_ptr;
int32_t *__restrict__ cache_seqlens;
// For the continuous batching case. Makes it so that the mamba state for
// the current batch doesn't need to be a contiguous tensor.
......@@ -52,6 +57,11 @@ struct ConvParamsBase {
index_t final_states_batch_stride;
index_t final_states_l_stride;
index_t final_states_c_stride;
void * conv_states_ptr;
index_t conv_states_batch_stride;
index_t conv_states_l_stride;
index_t conv_states_c_stride;
};
......
......@@ -54,10 +54,14 @@ struct SSMParamsBase {
void *__restrict__ delta_ptr;
void *__restrict__ delta_bias_ptr;
void *__restrict__ out_ptr;
void *__restrict__ x_ptr;
void *__restrict__ ssm_states_ptr;
void *__restrict__ z_ptr;
void *__restrict__ out_z_ptr;
void *__restrict__ index_ptr;
void *__restrict__ query_start_loc_ptr;
void *__restrict__ cache_indices_ptr;
void *__restrict__ has_initial_state_ptr;
};
......@@ -201,7 +205,7 @@ inline __device__ void load_input(typename Ktraits::input_t *u,
typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadT::TempStorage &smem_load,
int seqlen) {
if constexpr (Ktraits::kIsEvenLen) {
if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
......@@ -217,21 +221,6 @@ inline __device__ void load_input(typename Ktraits::input_t *u,
}
}
template<typename Ktraits>
inline __device__ void load_index(int *u,
int (&u_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index,
int seqlen) {
if constexpr (Ktraits::kIsEvenLen) {
auto& smem_load_index_vec = reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(smem_load_index);
Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(
reinterpret_cast<uint4*>(u),
reinterpret_cast<uint4(&)[Ktraits::kNLoadsIndex]>(u_vals)
);
} else {
Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0);
}
}
template<typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
......@@ -240,7 +229,7 @@ inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
int seqlen) {
constexpr int kNItems = Ktraits::kNItems;
typename Ktraits::input_t B_vals_load[kNItems];
if constexpr (Ktraits::kIsEvenLen) {
if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
......@@ -263,7 +252,7 @@ inline __device__ void store_output(typename Ktraits::input_t *out,
typename Ktraits::input_t write_vals[Ktraits::kNItems];
#pragma unroll
for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
if constexpr (Ktraits::kIsEvenLen) {
if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
......
This diff is collapsed.
......@@ -215,25 +215,30 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
std::vector<torch::Tensor> selective_scan_fwd(
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
const torch::Tensor& B, const torch::Tensor& C,
const c10::optional<torch::Tensor>& D_,
const c10::optional<torch::Tensor>& z_,
const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
const c10::optional<torch::Tensor>& index_,
const c10::optional<torch::Tensor>& x);
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const torch::Tensor& A, const torch::Tensor& B,
const torch::Tensor& C,
const c10::optional<torch::Tensor>& D_,
const c10::optional<torch::Tensor>& z_,
const c10::optional<torch::Tensor>& delta_bias_,
bool delta_softplus,
const c10::optional<torch::Tensor>& query_start_loc,
const c10::optional<torch::Tensor>& cache_indices,
const c10::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states);
at::Tensor causal_conv1d_update(
const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias, bool silu_activation,
const c10::optional<at::Tensor>& conv_state_indices);
const c10::optional<at::Tensor>& bias_, bool silu_activation,
const c10::optional<at::Tensor>& cache_seqlens_,
const c10::optional<at::Tensor>& conv_state_indices_);
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
const c10::optional<at::Tensor>& seq_idx_,
const c10::optional<at::Tensor>& initial_states_,
const c10::optional<at::Tensor>& final_states_out_,
const c10::optional<at::Tensor>& conv_states,
const c10::optional<at::Tensor>& query_start_loc,
const c10::optional<at::Tensor>& cache_indices,
const c10::optional<at::Tensor>& has_initial_state,
bool silu_activation);
#ifndef USE_ROCM
......
......@@ -273,26 +273,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
"Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? index_, Tensor!? x) -> Tensor[]");
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
ops.def(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias,"
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices) -> Tensor");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
ops.def(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor? seq_idx_,"
"Tensor? initial_states_,"
"Tensor!? final_states_out_,"
"Tensor!? conv_states,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation) -> Tensor");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#endif
......
......@@ -3,7 +3,6 @@ from typing import Optional
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
......@@ -57,43 +56,72 @@ def causal_conv1d_ref(
return (out, None) if not return_final_states else (out, final_states_out)
def causal_conv1d_update_ref(x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = None):
def causal_conv1d_update_ref(x,
conv_state,
weight,
bias=None,
activation=None,
cache_seqlens=None):
"""
x: (batch, dim)
conv_state: (batch, dim, width)
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the
conv_state starting at the index
@cache_seqlens % state_len before performing the convolution.
out: (batch, dim)
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
batch, dim = x.shape
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
batch, dim, seqlen = x.shape
width = weight.shape[1]
assert conv_state.shape == (batch, dim, width)
state_len = conv_state.shape[-1]
assert conv_state.shape == (batch, dim, state_len)
assert weight.shape == (dim, width)
conv_state.copy_(torch.roll(conv_state, shifts=-1,
dims=-1)) # Update state (B D W)
conv_state[:, :, -1] = x
out = torch.sum(conv_state * weight, dim=-1) # (B D)
if bias is not None:
out += bias
if cache_seqlens is None:
x_new = torch.cat([conv_state, x], dim=-1).to(
weight.dtype) # (batch, dim, state_len + seqlen)
conv_state.copy_(x_new[:, :, -state_len:])
else:
width_idx = torch.arange(
-(width - 1), 0, dtype=torch.long,
device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(
-1, dim, -1)
x_new = torch.cat([conv_state.gather(2, width_idx), x],
dim=-1).to(weight.dtype)
copy_idx = torch.arange(
seqlen, dtype=torch.long,
device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
copy_idx = torch.remainder(copy_idx,
state_len).unsqueeze(1).expand(-1, dim, -1)
conv_state.scatter_(2, copy_idx, x)
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0,
groups=dim)[:, :, -seqlen:]
if unsqueeze:
out = out.squeeze(-1)
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
def causal_conv1d_opcheck_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None,
return_final_states: bool = False,
final_states_out=None,
cu_seq_len: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
):
"""
......@@ -109,135 +137,93 @@ def causal_conv1d_opcheck_fn(
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(2) != 1 and x.stride(1) != 1:
if x.stride(-1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
if seq_idx is not None:
assert (initial_states is
None), "initial_states must be None if seq_idx is not None"
assert (not return_final_states
), "If seq_idx is not None, we don't return final_states_out"
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
if initial_states is not None and (initial_states.stride(2) != 1
and initial_states.stride(1) != 1):
initial_states = initial_states.contiguous()
if return_final_states:
assert (
x.stride(1) == 1
), "Only channel-last layout support returning final_states_out"
if final_states_out is not None:
assert (final_states_out.stride(2) == 1
or final_states_out.stride(1) == 1)
else:
batch, dim, seqlen = x.shape
width = weight.shape[1]
final_states_out = torch.empty(batch,
width - 1,
dim,
device=x.device,
dtype=x.dtype).transpose(1, 2)
else:
final_states_out = None
opcheck(torch.ops._C.causal_conv1d_fwd,
(x, weight, bias, seq_idx, initial_states, final_states_out,
activation in ["silu", "swish"]))
opcheck(torch.ops._C.causal_conv1d_fwd, (
x,
weight,
bias,
conv_states,
cu_seq_len,
cache_indices,
has_initial_state,
activation in ["silu", "swish"],
))
@pytest.mark.parametrize("return_final_states", [False, True])
@pytest.mark.parametrize("has_initial_states", [False, True])
@pytest.mark.parametrize("channel_last", [False, True])
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("seqlen", [128, 512, 4096])
@pytest.mark.parametrize('dim', [64, 4096 + 32])
@pytest.mark.parametrize('batch', [1, 2])
@pytest.mark.parametrize(
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
@pytest.mark.parametrize('dim', [64])
@pytest.mark.parametrize('batch', [1])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
itype, channel_last, has_initial_states,
return_final_states):
if not channel_last and (has_initial_states or return_final_states):
pytest.skip(
"Only channel_last support initial_states or return_final_states")
itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
seed_everything(0)
if not channel_last:
x = torch.randn(batch,
4096 + dim + 64,
seqlen,
device=device,
dtype=itype)[:, 4096:4096 + dim, :]
else:
x = rearrange(
torch.randn(batch,
seqlen,
4096 + dim + 64,
device=device,
dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s")
x = torch.randn(batch, dim, seqlen, device=device,
dtype=itype).contiguous()
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
if has_initial_states:
initial_states = torch.randn(batch,
width - 1,
dim,
device=device,
dtype=itype).transpose(1, 2)
else:
initial_states = None
x_ref = x.detach().clone()
weight_ref = weight.detach().clone()
bias_ref = bias.detach().clone() if bias is not None else None
initial_states_ref = initial_states.detach().clone(
initial_states = torch.randn(batch,
dim,
width - 1,
device=device,
dtype=itype)
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
initial_states_ref = initial_states.clone(
) if initial_states is not None else None
activation = None if not silu_activation else "silu"
out, final_states = causal_conv1d_fn(
x,
weight,
bias,
initial_states=initial_states,
return_final_states=return_final_states,
activation=activation)
out = causal_conv1d_fn(x,
weight,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=torch.ones(batch,
dtype=torch.bool,
device=x.device))
out_ref, final_states_ref = causal_conv1d_ref(
x_ref,
weight_ref,
bias_ref,
initial_states=initial_states_ref,
return_final_states=return_final_states,
return_final_states=True,
activation=activation)
causal_conv1d_opcheck_fn(x_ref,
weight_ref,
bias_ref,
initial_states=initial_states_ref,
return_final_states=return_final_states,
activation=activation)
if return_final_states:
assert final_states is not None and final_states_ref is not None
assert torch.allclose(final_states,
final_states_ref,
rtol=rtol,
atol=atol)
assert initial_states is not None and final_states_ref is not None
assert torch.allclose(initial_states,
final_states_ref,
rtol=rtol,
atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
if return_final_states:
out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)
out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)
causal_conv1d_opcheck_fn(x,
weight,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=torch.ones(batch,
dtype=torch.bool,
device=x.device))
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
@pytest.mark.parametrize("batch", [1, 2])
def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
......@@ -246,8 +232,9 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
# set seed
seed_everything(0)
batch = 2
x = torch.randn(batch, dim, device=device, dtype=itype)
conv_state = torch.randn(batch, dim, width, device=device, dtype=itype)
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
weight = torch.randn(dim,
width,
device=device,
......@@ -273,9 +260,15 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
assert torch.equal(conv_state, conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
opcheck(
torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation in ["silu", "swish"], None))
opcheck(torch.ops._C.causal_conv1d_update, (
x,
conv_state,
weight,
bias,
activation in ["silu", "swish"],
None,
None,
))
@pytest.mark.parametrize("itype",
......@@ -292,16 +285,16 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
torch.random.manual_seed(0)
# set )seed
seed_everything(0)
batch = 64
x = torch.randn(batch, dim, device=device, dtype=itype)
x = torch.randn(batch, dim, 1, device=device, dtype=itype)
total_entries = 10 * batch
conv_state = torch.randn(total_entries,
dim,
width,
width - 1,
device=device,
dtype=itype)
conv_state_indices = torch.randperm(total_entries)[:batch].to(
......@@ -332,3 +325,100 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
opcheck(torch.ops._C.causal_conv1d_update, (
x,
conv_state,
weight,
bias,
activation in ["silu", "swish"],
None,
conv_state_indices,
))
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize('seqlen',
[8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
@pytest.mark.parametrize('dim', [64, 4096])
def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
seed_everything(0)
batch = 1
seqlens = []
nsplits = 3
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
seqlens.append(
torch.diff(
torch.cat(
[torch.tensor([-1]), eos_pos,
torch.tensor([seqlen - 1])])).tolist())
assert sum(seqlens[-1]) == seqlen
assert all(s > 0 for s in seqlens[-1])
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0)
x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device,
dtype=itype)[:, 4096:4096 + dim, :]
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
activation = None if not silu_activation else "silu"
final_states = torch.randn(nsplits + 1,
dim,
width - 1,
device=x.device,
dtype=x.dtype)
final_states_ref = final_states.clone()
has_initial_states = torch.randint(0,
2, (cumsum.shape[0] - 1, ),
dtype=torch.bool,
device=x.device)
cache_indices = torch.randperm(cumsum.shape[0] - 1,
dtype=torch.int32,
device=x.device)
out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
cache_indices, has_initial_states, final_states,
activation)
out_ref = []
out_ref_b = []
splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
for i in range(len(seqlens[0])):
x_s = [v[i].unsqueeze(0) for v in splits][0]
out_ref_b.append(
causal_conv1d_ref(
x_s,
weight_ref,
bias_ref,
activation=activation,
return_final_states=True,
final_states_out=final_states_ref[cache_indices[i]].unsqueeze(
0),
initial_states=final_states_ref[cache_indices[i]].unsqueeze(0)
if has_initial_states[i] else None))
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
out_ref = torch.cat(out_ref, dim=0)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print("Output state max diff"
f":{(final_states - final_states_ref).abs().max()}")
print("Output state mean diff"
f":{(final_states - final_states_ref).abs().mean()}")
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
cache_indices, has_initial_states, final_states,
activation)
......@@ -98,8 +98,8 @@ def selective_scan_ref(u,
delta_bias=None,
delta_softplus=False,
return_last_state=False,
position_indices=None,
prev_state=None):
prev_state=None,
final_state_out=None):
"""
u: r(B D L)
delta: r(B D L)
......@@ -139,12 +139,8 @@ def selective_scan_ref(u,
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
if is_variable_C and C.dim() == 4:
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
last_state = None
for i in range(u.shape[2]):
if position_indices is not None and position_indices[0, i] == 0:
x = deltaB_u[:, :, i]
else:
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
if not is_variable_C:
y = torch.einsum('bdn,dn->bd', x, C)
else:
......@@ -153,14 +149,17 @@ def selective_scan_ref(u,
else:
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
if i == u.shape[2] - 1:
last_state = x
if final_state_out is None:
final_state_out = x
else:
final_state_out.copy_(x)
ys.append(y)
y = torch.stack(ys, dim=2) # (batch dim L)
out = y if D is None else y + u * rearrange(D, "d -> d 1")
if z is not None:
out = out * F.silu(z)
out = out.to(dtype=dtype_in)
return out if not return_last_state else (out, last_state)
return out if not return_last_state else (out, final_state_out)
def selective_scan_opcheck_fn(u,
......@@ -172,9 +171,10 @@ def selective_scan_opcheck_fn(u,
z=None,
delta_bias=None,
delta_softplus=False,
return_last_state=False,
position_indices=None,
prev_state=None):
cu_seq_len=None,
cache_indices=None,
has_initial_state=None,
ssm_states=None):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
"""
......@@ -190,36 +190,27 @@ def selective_scan_opcheck_fn(u,
C = C.contiguous()
if z is not None and z.stride(-1) != 1:
z = z.contiguous()
if B.dim() == 3:
if B.dim() == 3 and cu_seq_len is None:
B = B.unsqueeze(1)
if C.dim() == 3:
if B.dim() == 2 and cu_seq_len is not None:
B = B.unsqueeze(0)
if C.dim() == 3 and cu_seq_len is None:
C = C.unsqueeze(1)
n_chunks = int((u.shape[-1] + 2048 - 1) / 2048)
x = torch.zeros((
u.shape[0],
u.shape[1],
n_chunks,
int(A.shape[1] * 2),
),
device=u.device,
dtype=torch.float32,
requires_grad=False)
x[:, :, 0, 0::2] = 1
if prev_state is not None:
x[:, :, 0, 1::2].copy_(prev_state)
if C.dim() == 2 and cu_seq_len is not None:
C = C.unsqueeze(0)
# Disable test_autograd_registration for now as it seems to trigger
# a bogus error.
opcheck(torch.ops._C.selective_scan_fwd,
(u, delta, A, B, C, D, z, delta_bias, delta_softplus,
position_indices, x),
(u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len,
cache_indices, has_initial_state, ssm_states),
test_utils=["test_schema", "test_faketensor"])
@pytest.mark.parametrize('wtype', [torch.float32])
@pytest.mark.parametrize('itype', [torch.float32])
@pytest.mark.parametrize('itype',
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
@pytest.mark.parametrize("return_last_state", [True])
@pytest.mark.parametrize('has_delta_bias', [True])
@pytest.mark.parametrize('delta_softplus', [True])
@pytest.mark.parametrize('has_z', [True])
......@@ -229,8 +220,8 @@ def selective_scan_opcheck_fn(u,
@pytest.mark.parametrize("is_variable_B", [True])
@pytest.mark.parametrize("scan_chunks", [1, 2, 3])
def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
has_z, has_delta_bias, delta_softplus,
return_last_state, seqlen, itype, wtype, scan_chunks):
has_z, has_delta_bias, delta_softplus, seqlen, itype,
wtype, scan_chunks):
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
pytest.skip() # This config is not applicable
device = 'cuda'
......@@ -243,10 +234,11 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
atolw = max(atolw, atol)
# set seed
seed_everything(0)
batch_size = 2
batch_size = 1
dim = 4
dstate = 8
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
A_ref = A.clone()
if not is_variable_B:
B_shape = [dim, dstate]
elif varBC_groups == 1:
......@@ -256,6 +248,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
B = torch.randn(B_shape,
device=device,
dtype=wtype if not is_variable_B else itype)
B_ref = B.clone()
if not is_variable_C:
C_shape = [dim, dstate]
elif varBC_groups == 1:
......@@ -265,16 +258,25 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
C = torch.randn(C_shape,
device=device,
dtype=wtype if not is_variable_C else itype)
C_ref = C.clone()
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
D_ref = D.clone()
z = torch.randn(batch_size, dim, seqlen, device=device,
dtype=itype) if has_z else None
z_ref = z.clone() if has_z else None
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
) if has_delta_bias else None
u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
u_ref = u.clone()
delta = (0.5 *
torch.rand(batch_size, dim, seqlen, device=device, dtype=itype))
state = None
state_ref = None
delta_ref = delta.clone()
state_shape = (batch_size, u.shape[1], int(A.shape[1]))
state = torch.randn(state_shape,
device=u.device,
dtype=itype,
requires_grad=False)
state_ref = state.clone()
out = None
out_ref = None
outs = []
......@@ -294,40 +296,40 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
if has_z:
assert z is not None
_z = z[..., chunk_start:chunk_end]
out, *rest = selective_scan_fn(u[..., chunk_start:chunk_end],
delta[..., chunk_start:chunk_end],
A,
_B,
_C,
D,
z=_z,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
return_last_state=return_last_state,
prev_state=state if c > 0 else None)
out = selective_scan_fn(
u[..., chunk_start:chunk_end],
state,
delta[..., chunk_start:chunk_end],
A,
_B,
_C,
D,
z=_z,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
has_initial_state=torch.ones(batch_size,
device=u.device,
dtype=torch.bool) if c > 0 else None)
outs.append(out)
if return_last_state:
state = rest[0]
if len(outs) > 1:
out = torch.cat(outs, dim=-1)
out_ref, *rest = selective_scan_ref(u,
delta,
A,
B,
C,
D,
z=z,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
return_last_state=return_last_state)
if return_last_state:
state_ref = rest[0]
out_ref, state_ref, *rest = selective_scan_ref(
u_ref,
delta_ref,
A_ref,
B_ref,
C_ref,
D_ref,
z=z_ref,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
return_last_state=True)
assert out is not None and out_ref is not None
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
if return_last_state:
assert state is not None and state_ref is not None
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
assert state is not None and state_ref is not None
assert torch.allclose(state, state_ref.to(itype), rtol=rtol, atol=atol)
selective_scan_opcheck_fn(u,
delta,
......@@ -335,10 +337,10 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
B,
C,
D,
z=z,
z,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
return_last_state=return_last_state)
ssm_states=state)
@pytest.mark.parametrize("itype",
......@@ -391,9 +393,131 @@ def test_selective_state_update(dim, dstate, has_z, itype):
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize('wtype', [torch.float32])
@pytest.mark.parametrize('itype', [torch.float32])
@pytest.mark.parametrize('seqlen', [1, 128, 129, 256, 512, 1024, 2048, 4096])
@pytest.mark.parametrize("return_last_state", [True])
@pytest.mark.parametrize('has_delta_bias', [True])
@pytest.mark.parametrize('delta_softplus', [True])
@pytest.mark.parametrize('has_z', [True])
@pytest.mark.parametrize('has_D', [True])
@pytest.mark.parametrize("varBC_groups", [1, 2])
@pytest.mark.parametrize("is_variable_C", [True])
@pytest.mark.parametrize("is_variable_B", [True])
def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
has_D, has_z, has_delta_bias, delta_softplus,
return_last_state, seqlen, itype, wtype):
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
pytest.skip() # This config is not applicable
device = 'cuda'
rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 3e-2, 5e-2
rtolw, atolw = (1e-3, 1e-3)
if has_z: # If we have z, the errors on the weights seem higher
rtolw = max(rtolw, rtol)
atolw = max(atolw, atol)
# set seed
torch.random.manual_seed(0)
seqlens = []
nsplits = 3
if seqlen < 10:
nsplits = 0
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
seqlens.append(
torch.diff(
torch.cat(
[torch.tensor([-1]), eos_pos,
torch.tensor([seqlen - 1])])).tolist())
assert sum(seqlens[-1]) == seqlen
assert all(s > 0 for s in seqlens[-1])
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0).cuda()
dim = 4
dstate = 8
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
A_ref = A.clone()
B_shape = [varBC_groups, dstate, seqlen]
B = torch.randn(B_shape,
device=device,
dtype=wtype if not is_variable_B else itype)
B_ref = B.clone()
C_shape = [varBC_groups, dstate, seqlen]
C = torch.randn(C_shape,
device=device,
dtype=wtype if not is_variable_C else itype)
C_ref = C.clone()
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
D_ref = D.clone()
z = torch.randn(dim, seqlen, device=device, dtype=itype)
z_ref = z.clone()
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
) if has_delta_bias else None
u = torch.randn(dim, seqlen, device=device, dtype=itype)
u_ref = u.clone()
delta = (0.5 * torch.rand(dim, seqlen, device=device, dtype=itype))
delta_ref = delta.clone()
out = None
out_ref = None
prev_state_shape = (cumsum.shape[0] - 1, u.shape[0], int(A.shape[1]))
prev_state = torch.randn(prev_state_shape,
device=u.device,
dtype=itype,
requires_grad=False)
prev_state_ref = prev_state.clone()
cache_indices = torch.randperm(cumsum.shape[0] - 1,
dtype=torch.int32,
device=u.device)
has_initial_state = torch.randint(0,
2, (cumsum.shape[0] - 1, ),
dtype=torch.bool,
device=u.device)
out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias,
delta_softplus, cumsum, cache_indices,
has_initial_state)
outs_ref = []
splits = [
torch.split(var, seqlens[0], dim=-1)
for var in (u_ref, delta_ref, B_ref, C_ref, z_ref)
]
for i in range(len(seqlens[0])):
u_s, delta_s, B_s, C_s, z_s = [v[i].unsqueeze(0) for v in splits]
out_ref_s, _ = selective_scan_ref(
u_s,
delta_s,
A_ref,
B_s,
C_s,
D_ref,
z=z_s,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
return_last_state=return_last_state,
prev_state=prev_state_ref[cache_indices[i]].unsqueeze(0)
if has_initial_state[i] else None,
final_state_out=prev_state_ref[cache_indices[i]].unsqueeze(0))
outs_ref.append(out_ref_s)
out_ref = torch.cat(outs_ref, dim=-1) if len(outs_ref) > 1 else outs_ref[0]
print("Output diff max", (out - out_ref[0]).max())
print("Output diff mean", (out - out_ref[0]).mean())
print("Output state diff max", (prev_state - prev_state_ref).max())
print("Output state diff mean", (prev_state - prev_state_ref).mean())
assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref[0], rtol=rtol, atol=atol)
selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias,
delta_softplus, cumsum, cache_indices,
has_initial_state, prev_state)
@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
......@@ -405,7 +529,7 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
atol *= 2
# set seed
torch.random.manual_seed(0)
batch_size = 16
batch_size = 3
total_entries = 10 * batch_size
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
......@@ -443,6 +567,11 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
dt_bias=dt_bias,
dt_softplus=True)
print("Output diff max", (out - out_ref[0]).max())
print("Output diff mean", (out - out_ref[0]).mean())
print("Output state diff max", (state[state_indices, :] - state_ref).max())
print("Output state diff mean",
(state[state_indices, :] - state_ref).mean())
assert torch.allclose(state[state_indices, :],
state_ref,
rtol=rtol,
......@@ -465,7 +594,7 @@ def test_selective_state_update_with_heads_with_batch_indices(
rtol, atol = 1e-1, 1e-1
# set seed
torch.random.manual_seed(0)
batch_size = 16
batch_size = 3
headdim = 64
nheads = dim // headdim
......
import pytest
from vllm.sampling_params import SamplingParams
from vllm.worker.model_runner import _get_graph_batch_size
from ...utils import check_outputs_equal
MODELS = ["ai21labs/Jamba-tiny-random"]
MODELS = ["ai21labs/Jamba-tiny-dev"]
# Fails due to usage of MoE as MLP(E=1_, which is different than the HF impl
# TODO: Fix this with trained model
@pytest.mark.skip()
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_models(
hf_runner,
vllm_runner,
......@@ -22,7 +20,14 @@ def test_models(
max_tokens: int,
) -> None:
with hf_runner(model, dtype=dtype) as hf_model:
with hf_runner(
model,
dtype=dtype,
model_kwargs={
"use_mamba_kernels":
False, # mamba kernels are not installed so HF
# don't use them
}) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model, dtype=dtype) as vllm_model:
......@@ -38,8 +43,8 @@ def test_models(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_batching(
vllm_runner,
example_prompts,
......@@ -65,6 +70,107 @@ def test_batching(
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float16"])
@pytest.mark.parametrize("max_tokens", [10])
def test_mamba_prefill_chunking_with_parallel_sampling(
hf_runner, vllm_runner, example_prompts, model: str, dtype: str,
max_tokens: int) -> None:
# Tests prefill chunking in conjunction with n>1, in this case,
# prefill is populated with decoding tokens and we test that it
# doesn't fail This test might fail if cache is not allocated
# correctly for n > 1 decoding steps inside a
# chunked prefill forward pass (where we have both prefills
# and decoding together )
sampling_params = SamplingParams(n=3,
temperature=1,
seed=0,
max_tokens=max_tokens)
with vllm_runner(
model,
dtype=dtype,
enable_chunked_prefill=True,
max_num_batched_tokens=30,
max_num_seqs=10 # forces prefill chunks with decoding
) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
model: str, dtype: str,
max_tokens: int) -> None:
# numeric error during prefill chucking produces different generation
# compared to w/o prefill chunking for those examples, removed them for now
example_prompts.pop(7)
example_prompts.pop(2)
example_prompts.pop(1)
with hf_runner(
model,
dtype=dtype,
model_kwargs={
"use_mamba_kernels":
False, # mamba kernels are not installed so HF
# don't use them
}) as hf_model:
non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model,
dtype=dtype,
enable_chunked_prefill=True,
max_num_batched_tokens=5,
max_num_seqs=2) as vllm_model:
chunked = vllm_model.generate_greedy(example_prompts,
max_tokens=max_tokens)
check_outputs_equal(
outputs_0_lst=chunked,
outputs_1_lst=non_chunked,
name_0="chunked",
name_1="non_chunked",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [15])
def test_parallel_sampling(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
for_loop_outputs = []
for _ in range(10):
for_loop_outputs.append(
# using example_prompts index 1 instead of 0 since with 0 the
# logprobs get really close and the test doesn't pass
vllm_model.generate_greedy([example_prompts[1]], max_tokens)
[0])
sampling_params = SamplingParams(n=10,
temperature=0.001,
seed=0,
max_tokens=max_tokens)
n_lt_1_outputs = vllm_model.generate([example_prompts[1]],
sampling_params)
token_ids, texts = n_lt_1_outputs[0]
n_lt_1_outputs = [(token_id, text)
for token_id, text in zip(token_ids, texts)]
check_outputs_equal(
outputs_0_lst=n_lt_1_outputs,
outputs_1_lst=for_loop_outputs,
name_0="vllm_n_lt_1_outputs",
name_1="vllm",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [20])
......
......@@ -440,9 +440,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@torch.library.register_fake("_C::causal_conv1d_fwd")
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
seq_idx_: Optional[torch.Tensor],
initial_states_: Optional[torch.Tensor],
final_states_out_: Optional[torch.Tensor],
conv_states: Optional[torch.Tensor],
cu_seq_len: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
return torch.empty_like(x)
......@@ -450,22 +451,22 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def causal_conv1d_update_fake(
x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor], silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
return torch.empty_like(x)
@torch.library.register_fake("_C::selective_scan_fwd")
def selective_scan_fwd_fake(
u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor],
z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor],
delta_softplus: bool, index_: Optional[torch.Tensor],
x: Optional[torch.Tensor]) -> List[torch.Tensor]:
a = torch.empty_like(u)
if z_ is not None:
c = torch.empty_like(z_)
return [a, c]
else:
return [a]
def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
A: torch.Tensor, B: torch.Tensor,
C: torch.Tensor, D_: Optional[torch.Tensor],
z_: Optional[torch.Tensor],
delta_bias_: Optional[torch.Tensor],
delta_softplus: bool,
cu_seq_len: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
ssm_states: Optional[torch.Tensor]) -> None:
return None
# cutlass
......@@ -761,37 +762,37 @@ def ggml_mul_mat_a8(
# mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
seq_idx_: Optional[torch.Tensor],
initial_states_: Optional[torch.Tensor],
final_states_out_: Optional[torch.Tensor],
conv_states: Optional[torch.Tensor],
query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_,
initial_states_, final_states_out_,
silu_activation)
return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
query_start_loc, cache_indices,
has_initial_state, silu_activation)
def causal_conv1d_update(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias_: Optional[torch.Tensor],
silu_activation: bool,
conv_state_indices: Optional[torch.Tensor],
) -> torch.Tensor:
x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor], silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
silu_activation,
silu_activation, cache_seqlens,
conv_state_indices)
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, C: torch.Tensor,
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
delta_bias_: Optional[torch.Tensor],
delta_softplus: bool, index_: Optional[torch.Tensor],
x: Optional[torch.Tensor]) -> List[torch.Tensor]:
return torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_,
delta_bias_, delta_softplus, index_,
x)
def selective_scan_fwd(
u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor,
C: torch.Tensor, D_: Optional[torch.Tensor],
z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor],
delta_softplus: bool, query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor], ssm_states: torch.Tensor):
torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_,
delta_softplus, query_start_loc,
cache_indices, has_initial_state,
ssm_states)
# moe
......
......@@ -12,59 +12,44 @@ def causal_conv1d_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None,
return_final_states: bool = False,
final_states_out=None,
activation: str = "silu",
query_start_loc: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
):
"""
x: (batch, dim, seqlen)
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
sequences are concatenated from left to right for varlen
weight: (dim, width)
bias: (dim,)
seq_idx: (batch, seqlen)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1), to be written to
query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in
the batch, used to index into sequence. prepended by 0.
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
indicates the corresponding state index,
like so: conv_state = conv_states[cache_indices[batch_id]]
has_initial_state: (batch) bool
indicates whether should the kernel take the current state as initial
state for the calculations
conv_states: (...,dim,width - 1) itype
updated inplace if provided
activation: either None or "silu" or "swish"
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(2) != 1 and x.stride(1) != 1:
if x.stride(-1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
if seq_idx is not None:
assert (initial_states is
None), "initial_states must be None if seq_idx is not None"
assert (not return_final_states
), "If seq_idx is not None, we don't return final_states_out"
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
if initial_states is not None and (initial_states.stride(2) != 1
and initial_states.stride(1) != 1):
initial_states = initial_states.contiguous()
if return_final_states:
assert (
x.stride(1) == 1
), "Only channel-last layout support returning final_states_out"
if final_states_out is not None:
assert (final_states_out.stride(2) == 1
or final_states_out.stride(1) == 1)
else:
batch, dim, seqlen = x.shape
width = weight.shape[1]
final_states_out = torch.empty(batch,
width - 1,
dim,
device=x.device,
dtype=x.dtype).transpose(1, 2)
else:
final_states_out = None
out = ops.causal_conv1d_fwd(x, weight, bias, seq_idx, initial_states,
final_states_out, activation
out = ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc,
cache_indices, has_initial_state, activation
in ["silu", "swish"])
return (out, None) if not return_final_states else (out, final_states_out)
return out
def causal_conv1d_update(x: torch.Tensor,
......@@ -72,21 +57,33 @@ def causal_conv1d_update(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = None,
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None):
"""
x: (batch, dim)
conv_state: (batch, dim, width)
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the conv_state
starting at the index
@cache_seqlens % state_len.
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
out: (batch, dim)
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
activation_bool = activation in ["silu", "swish"]
return ops.causal_conv1d_update(x, conv_state, weight, bias,
activation_bool, conv_state_indices)
activation_val = activation in ["silu", "swish"]
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
out = ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val,
cache_seqlens, conv_state_indices)
if unsqueeze:
out = out.squeeze(-1)
return out
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
from typing import Tuple
import torch
import triton
import triton.language as tl
......@@ -317,20 +319,50 @@ def selective_state_update(state,
return out
def selective_scan_fn(u,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
return_last_state=False,
position_indices=None,
prev_state=None):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
def selective_scan_fn(
u,
ssm_states,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
query_start_loc=None,
cache_indices=None,
has_initial_state=None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
u: (dim, total_length) for varlen or (batch, dim, seqlen)
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
A: (dim, dstate)
B: (ngroups, dstate, total_length) for varlen or
(batch,ngroups,dstate,seqlen)
C: (ngroups, dstate, total_length) for varlen or
(batch,ngroups,dstate,seqlen)
D: (dim,)
z: (dim, total_length) for varlen or (batch, dim, seqlen)
dt_bias: (dim,) or (dim)
query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in
the batch, used to index into sequence. prepended with 0.
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
A tensor with each cell is a correspondent
input and output ssm_state index
has_initial_state: (batch) bool
A tensor populated with ones and zeros,
indicate if the ssm_state at the corresponding index should be
used as initial state. Not providing argument assumes
there's no initial state
returns
output: (dim, total_length) for varlen or (batch, dim, seqlen)
supports inplace replacement
last_state has shape (batch, dim, dstate).
supports inplace replacement if ssm_state was provided
"""
if u.stride(-1) != 1:
u = u.contiguous()
......@@ -344,28 +376,20 @@ def selective_scan_fn(u,
C = C.contiguous()
if z is not None and z.stride(-1) != 1:
z = z.contiguous()
if B.dim() == 3:
if B.dim() == 3 and query_start_loc is None:
B = B.unsqueeze(1)
if C.dim() == 3:
if B.dim() == 2 and query_start_loc is not None:
B = B.unsqueeze(0)
if C.dim() == 3 and query_start_loc is None:
C = C.unsqueeze(1)
n_chunks = int((u.shape[-1] + 2048 - 1) / 2048)
x = torch.zeros((
u.shape[0],
u.shape[1],
n_chunks,
int(A.shape[1] * 2),
),
device=u.device,
dtype=torch.float32,
requires_grad=False)
x[:, :, 0, 0::2] = 1
if prev_state is not None:
x[:, :, 0, 1::2].copy_(prev_state)
out, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
delta_softplus, position_indices, x)
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
if C.dim() == 2 and query_start_loc is not None:
C = C.unsqueeze(0)
ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus,
query_start_loc, cache_indices, has_initial_state,
ssm_states)
if z is None:
return out if not return_last_state else (out, last_state)
return delta # output written inplace to delta
else:
out_z = rest[0]
return out_z if not return_last_state else (out_z, last_state)
return z # output written inplace to z
......@@ -138,42 +138,47 @@ class JambaMambaMixer(nn.Module):
self.c_layernorm = RMSNorm(self.ssm_state_size,
eps=config.rms_norm_eps)
def mamba_forward(self,
hidden_states: torch.Tensor,
cache_params: MambaCacheParams = None):
def forward(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata, conv_state: torch.Tensor,
ssm_state: torch.Tensor):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(1, 2)
hidden_states, gate = projected_states.chunk(2, dim=1)
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
hidden_states, gate = projected_states.chunk(2, dim=-2)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if cache_params is not None and not cache_params.is_prompt:
hidden_states = causal_conv1d_update(
hidden_states.squeeze(-1),
cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
)
hidden_states = hidden_states.unsqueeze(-1)
else:
if cache_params is not None:
conv_states = nn.functional.pad(
hidden_states,
(self.conv_kernel_size - hidden_states.shape[-1], 0))
cache_params.conv_state.copy_(conv_states)
hidden_states, _ = causal_conv1d_fn(
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states = causal_conv1d_fn(
hidden_states,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
)
hidden_states = hidden_states.transpose(0, 1)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0]
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
time_step, B, C = torch.split(
ssm_parameters,
......@@ -184,72 +189,46 @@ class JambaMambaMixer(nn.Module):
B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous())
discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2)
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
self.dt_proj, "bias") else None)
if cache_params is not None and not cache_params.is_prompt:
scan_outputs = selective_state_update(
cache_params.ssm_state,
hidden_states[..., 0],
discrete_time_step[..., 0],
self.A,
B[:, 0],
C[:, 0],
self.D,
gate[..., 0],
time_proj_bias,
dt_softplus=True,
).unsqueeze(-1)
else:
scan_outputs, ssm_state = selective_scan_fn(
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
ssm_state,
discrete_time_step,
self.A,
B.transpose(1, 2),
C.transpose(1, 2),
B.transpose(-2, -1),
C.transpose(-2, -1),
self.D.float(),
gate,
time_proj_bias,
delta_softplus=True,
return_last_state=True,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = selective_state_update(
ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
B,
C,
self.D,
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
)
if ssm_state is not None and cache_params is not None:
cache_params.ssm_state.copy_(ssm_state)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0]
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-1))[0]
return contextualized_states
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
):
if attn_metadata.prefill_metadata is not None:
offset = 0
for i, prompt_len in enumerate(
attn_metadata.prefill_metadata.seq_lens):
cache = MambaCacheParams(True,
conv_state=conv_state[i].unsqueeze(0),
ssm_state=ssm_state[i].unsqueeze(0))
hidden_states[offset:offset + prompt_len].copy_(
self.mamba_forward(hidden_states[offset:offset +
prompt_len].unsqueeze(0),
cache_params=cache)[0])
offset += prompt_len
else:
cache = MambaCacheParams(False,
conv_state=conv_state,
ssm_state=ssm_state)
hidden_states = self.mamba_forward(hidden_states.unsqueeze(1),
cache_params=cache)
hidden_states = hidden_states.squeeze(1)
return hidden_states
class JambaMoE(nn.Module):
......@@ -571,8 +550,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
lora_config: Optional[LoRAConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
) -> None:
assert not scheduler_config.chunked_prefill_enabled, \
"Jamba currently does not support chunked prefill"
assert not cache_config.enable_prefix_caching, \
"Jamba currently does not support prefix caching"
......@@ -616,18 +593,10 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)
batch_size = input_ids.shape[0]
if attn_metadata.prefill_metadata:
batch_size = len(request_ids_to_seq_ids)
mamba_cache = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, batch_size, finished_requests_ids)
mamba_cache = self._release_finished_and_prepare_mamba_cache(
finished_requests_ids, request_ids_to_seq_ids)
else:
# CUDA graph capturing runs
mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]
......@@ -699,13 +668,15 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
batch_size: int, finished_requests_ids: List[str]):
finished_requests_ids: List[str]
) -> Tuple[torch.Tensor, torch.Tensor]:
running_indices = []
request_ids_to_seq_ids_flatten = [
(req_id, seq_id)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
batch_size = len(request_ids_to_seq_ids_flatten)
for dest_index, (request_id,
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
if request_id in finished_requests_ids:
......@@ -769,22 +740,21 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
seq_ids2index.update({seq_id: to_index})
return
def _release_finished_and_prepare_mamba_cache(
self, finished_requests_ids,
request_ids_to_seq_ids) -> Tuple[torch.Tensor, torch.Tensor]:
self._release_mamba_cache(finished_requests_ids)
return self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
finished_requests_ids)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant Mamba cache into the CUDA graph input buffer
that was provided during the capture runs
(JambaForCausalLM.mamba_gc_cache_buffer).
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
cg_batch_size = input_buffers['input_ids'].shape[0]
self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
cg_batch_size,
finished_requests_ids)
self._release_finished_and_prepare_mamba_cache(
kwargs["finished_requests_ids"], kwargs["request_ids_to_seq_ids"])
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
......@@ -819,7 +789,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
hidden_size = self.config.hidden_size
conv_state_shape = (
self.config.mamba_expand * hidden_size // world_size,
self.config.mamba_d_conv,
self.config.mamba_d_conv - 1,
)
temporal_state_shape = (
self.config.mamba_expand * self.config.hidden_size // world_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment