Commit 2eefe3d6 authored by luopl
Parent: b7535e7c
Pipeline #1735: failed

add mamba
# Copyright (C) 2023, Tri Dao.
import math
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref
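
# The test below compares the fused selective-scan CUDA kernel
# (selective_scan_fn) against the pure-PyTorch reference (selective_scan_ref),
# checking both the forward output / last state and the gradients of every
# input. The commented-out @pytest.mark.parametrize lines list the wider
# sweeps; only the uncommented values are exercised here.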
# @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
@pytest.mark.parametrize('wtype', [torch.float32])
# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize('itype', [torch.float32])
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
# @pytest.mark.parametrize('seqlen', [128])
# @pytest.mark.parametrize("return_last_state", [False, True])
@pytest.mark.parametrize("return_last_state", [True])
# @pytest.mark.parametrize('has_delta_bias', [False, True])
@pytest.mark.parametrize('has_delta_bias', [True])
# @pytest.mark.parametrize('delta_softplus', [False, True])
@pytest.mark.parametrize('delta_softplus', [True])
# @pytest.mark.parametrize('has_z', [False, True])
@pytest.mark.parametrize('has_z', [True])
# @pytest.mark.parametrize('has_D', [False, True])
@pytest.mark.parametrize('has_D', [True])
@pytest.mark.parametrize("varBC_groups", [1, 2])
# @pytest.mark.parametrize("varBC_groups", [1])
# @pytest.mark.parametrize("is_variable_C", [False, True])
@pytest.mark.parametrize("is_variable_C", [True])
# @pytest.mark.parametrize("is_variable_B", [False, True])
@pytest.mark.parametrize("is_variable_B", [True])
def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias,
delta_softplus, return_last_state, seqlen, itype, wtype):
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
pytest.skip() # This config is not applicable
device = 'cuda'
rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 3e-2, 5e-2
rtolw, atolw = (1e-3, 1e-3)
if has_z: # If we have z, the errors on the weights seem higher
rtolw = max(rtolw, rtol)
atolw = max(atolw, atol)
# set seed
torch.random.manual_seed(0)
batch_size = 2
dim = 4
dstate = 8
is_complex = wtype == torch.complex64
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
if not is_variable_B:
B_shape = (dim, dstate)
elif varBC_groups == 1:
B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
else:
B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype,
requires_grad=True)
if not is_variable_C:
C_shape = (dim, dstate)
elif varBC_groups == 1:
C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
else:
C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype,
requires_grad=True)
if has_D:
D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
else:
D = None
if has_z:
z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
else:
z = None
if has_delta_bias:
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
else:
delta_bias = None
u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_()
A_ref = A.detach().clone().requires_grad_()
B_ref = B.detach().clone().requires_grad_()
C_ref = C.detach().clone().requires_grad_()
D_ref = D.detach().clone().requires_grad_() if D is not None else None
z_ref = z.detach().clone().requires_grad_() if z is not None else None
u_ref = u.detach().clone().requires_grad_()
delta_ref = delta.detach().clone().requires_grad_()
delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
out, *rest = selective_scan_fn(
u, delta, A, B, C, D, z=z,
delta_bias=delta_bias, delta_softplus=delta_softplus,
return_last_state=return_last_state
)
if return_last_state:
state = rest[0]
out_ref, *rest = selective_scan_ref(
u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref,
delta_bias=delta_bias_ref, delta_softplus=delta_softplus,
return_last_state=return_last_state
)
if return_last_state:
state_ref = rest[0]
# dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
# dt_u = delta * u
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
if return_last_state:
print(f'State max diff: {(state - state_ref).abs().max().item()}')
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
g = torch.randn_like(out)
out_ref.backward(g)
out.backward(g)
print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}')
print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}')
print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
if has_D:
print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
if has_z:
print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}')
if has_delta_bias:
print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
atol=atolw if not is_variable_B else atol)
assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
atol=atolw if not is_variable_C else atol)
if has_D:
assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
if has_z:
assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw)
if has_delta_bias:
assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
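
# test_mamba_inner_fn checks the fused Mamba block path (causal conv1d,
# x/delta projections, selective scan, output projection) against
# mamba_inner_ref. Only the forward output is asserted; the gradient max
# diffs are printed but their allclose checks are left commented out.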
@pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
# @pytest.mark.parametrize('wtype', [torch.complex64])
# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize('itype', [torch.float32])
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
@pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("is_variable_C", [False, True])
# @pytest.mark.parametrize("is_variable_C", [False])
@pytest.mark.parametrize("is_variable_B", [False, True])
# @pytest.mark.parametrize("is_variable_B", [True])
def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype):
device = 'cuda'
rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 3e-2, 5e-2
rtolw, atolw = (1e-3, 1e-3)
# If we have z, the errors on the weights seem higher
rtolw = max(rtolw, rtol)
atolw = max(atolw, atol)
# set seed
torch.random.manual_seed(0)
batch_size = 2
dim = 768
dstate = 8
dt_rank = 48
is_complex = wtype == torch.complex64
xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True)
conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True)
conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate
* (1 if not is_complex else 2),
dim, device=device, dtype=itype, requires_grad=True)
delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True)
out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True)
out_proj_bias = None
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
if not is_variable_B else None)
C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
if not is_variable_C else None)
D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
B_proj_bias = None
C_proj_bias = None
xz_ref = xz.detach().clone().requires_grad_()
conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_()
conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_()
x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_()
delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_()
out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_()
out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_()
if out_proj_bias is not None else None)
A_ref = A.detach().clone().requires_grad_()
B_ref = B.detach().clone().requires_grad_() if B is not None else None
C_ref = C.detach().clone().requires_grad_() if C is not None else None
D_ref = D.detach().clone().requires_grad_()
delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
out_proj_weight, out_proj_bias,
A, B, C, D, delta_bias=delta_bias, delta_softplus=True)
out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref,
delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref,
A_ref, B_ref, C_ref, D_ref,
delta_bias=delta_bias_ref, delta_softplus=True)
# dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
# dt_u = delta * u
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
g = torch.randn_like(out)
out_ref.backward(g)
out.backward(g)
print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}')
print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
if not is_variable_B:
print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
if not is_variable_C:
print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
print(f'dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}')
print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}')
print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}')
print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}')
print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}')
# assert torch.allclose(xz.grad, xz_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
# assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
# assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
# assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
# atol=atolw if not is_variable_B else atol)
# assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
# atol=atolw if not is_variable_C else atol)
# assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
# assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
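
# The block below appears to be a separate test file in this commit. It
# exercises the Triton gated LayerNorm/RMSNorm kernel (layernorm_fn) against
# pure-PyTorch references, asserting that the kernel's max error on outputs
# and gradients stays within 2x the error of the PyTorch baseline plus atol.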
import math
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange, repeat
from mamba_ssm.ops.triton.layernorm_gated import layernorm_fn, rms_norm_ref
@pytest.mark.parametrize("norm_before_gate", [True, False])
# @pytest.mark.parametrize("norm_before_gate", [False])
@pytest.mark.parametrize("has_group", [False, True])
# @pytest.mark.parametrize("has_group", [False])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_z", [False, True])
# @pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("has_bias", [False, True])
# @pytest.mark.parametrize("has_bias", [False])
# @pytest.mark.parametrize('dtype', [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize('dtype', [torch.float16])
# @pytest.mark.parametrize("wtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("wtype", [torch.float32])
@pytest.mark.parametrize('d', [2048, 4096])
# @pytest.mark.parametrize('d', [4096])
def test_layer_norm_gated(d, dtype, wtype, has_bias, has_z, is_rms_norm, has_group, norm_before_gate):
if not has_z and not norm_before_gate:
pytest.skip()
if not norm_before_gate and not is_rms_norm: # Reference LN isn't implemented for this case yet
pytest.skip()
device = 'cuda'
rtol, atol = (1e-5, 1e-5) if dtype == torch.float32 else (1e-2, 8e-3)
group_size = None if not has_group else 64
# set seed
torch.random.manual_seed(0)
batch = 16
seqlen = 1024
x = torch.randn(batch, seqlen, d, dtype=dtype, device=device, requires_grad=True)
if has_z:
z = torch.randn(batch, seqlen, d, dtype=dtype, device=device, requires_grad=True)
else:
z = None
weight = torch.randn(d, dtype=wtype, device=device, requires_grad=True)
if has_bias:
bias = torch.randn(d, dtype=wtype, device=device, requires_grad=True)
else:
bias = None
x_ref = x.detach().clone().requires_grad_()
x_pt = x.detach().clone().requires_grad_()
z_ref = z.detach().clone().requires_grad_() if z is not None else None
z_pt = z.detach().clone().requires_grad_() if z is not None else None
weight_ref = weight.detach().clone().requires_grad_()
weight_pt = weight.detach().clone().requires_grad_()
bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
out = layernorm_fn(x, weight, bias, z=z, eps=1e-5, group_size=group_size, norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm)
if not is_rms_norm:
if not has_group:
out_ref = F.layer_norm(x_ref.float(), (d,), weight=weight_ref.float(), bias=bias_ref.float() if bias_ref is not None else None, eps=1e-5)
out_pt = F.layer_norm(x_pt.to(wtype), (d,), weight=weight_pt, bias=bias_pt, eps=1e-5)
else:
out_ref = rearrange(F.layer_norm(rearrange(x_ref, "... (g d) -> ... g d", d=group_size).float(), (group_size,), eps=1e-5), "... g d -> ... (g d)") * weight_ref.float()
if has_bias:
out_ref = out_ref + bias_ref.float()
out_pt = rearrange(F.layer_norm(rearrange(x_pt, "... (g d) -> ... g d", d=group_size), (group_size,), eps=1e-5), "... g d -> ... (g d)") * weight_pt
if has_bias:
out_pt = out_pt + bias_pt
if has_z and norm_before_gate:
out_ref = out_ref * F.silu(z_ref.float())
out_pt = out_pt * F.silu(z_pt)
else:
out_ref = rms_norm_ref(x_ref, weight_ref, bias_ref, z=z_ref, eps=1e-5, group_size=group_size,
norm_before_gate=norm_before_gate)
out_pt = rms_norm_ref(x_pt, weight_pt, bias_pt, z=z_pt, eps=1e-5, group_size=group_size,
norm_before_gate=norm_before_gate, upcast=False)
print(f"Max diff = {(out - out_ref).abs().max().item()}")
print(f"Max diff Pytorch = {(out_pt - out_ref).abs().max().item()}")
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + atol
g = torch.randn_like(out)
out.backward(g)
out_ref.backward(g)
out_pt.backward(g)
print(f"Max dx diff = {(x.grad - x_ref.grad).abs().max().item()}")
print(f"Max dx diff Pytorch = {(x_pt.grad - x_ref.grad).abs().max().item()}")
if has_z:
print(f"Max dz diff = {(z.grad - z_ref.grad).abs().max().item()}")
print(f"Max dz diff Pytorch = {(z_pt.grad - z_ref.grad).abs().max().item()}")
print(f"Max dw diff = {(weight.grad - weight_ref.grad).abs().max().item()}")
print(f"Max dw diff Pytorch = {(weight_pt.grad - weight_ref.grad).abs().max().item()}")
if has_bias:
print(f"Max db diff = {(bias.grad - bias_ref.grad).abs().max().item()}")
print(f"Max db diff Pytorch = {(bias_pt.grad - bias_ref.grad).abs().max().item()}")
assert (x.grad - x_ref.grad).abs().max().item() <= 2 * (x_pt.grad - x_ref.grad).abs().max().item() + atol
if has_z:
assert (z.grad - z_ref.grad).abs().max().item() <= 2 * (z_pt.grad - z_ref.grad).abs().max().item() + atol
assert (weight.grad - weight_ref.grad).abs().max().item() <= 2 * (weight_pt.grad - weight_ref.grad).abs().max().item() + atol
if has_bias:
assert (bias.grad - bias_ref.grad).abs().max().item() <= 2 * (bias_pt.grad - bias_ref.grad).abs().max().item() + atol
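
# The next block appears to be another test file: it covers the Triton
# selective_state_update kernel, i.e. a single decoding step. The kernel
# updates `state` in place and returns the output; both the updated state and
# the output are compared against selective_state_update_ref.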
# Copyright (C) 2023, Tri Dao.
import math
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange, repeat
from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
# @pytest.mark.parametrize('itype', [torch.float16])
@pytest.mark.parametrize("has_z", [False, True])
# @pytest.mark.parametrize('has_z', [True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
# @pytest.mark.parametrize("dstate", [16])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# @pytest.mark.parametrize("dim", [2048])
def test_selective_state_update(dim, dstate, has_z, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
if torch.version.hip:
atol *= 2
# set seed
torch.random.manual_seed(0)
batch_size = 2
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=itype)
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(batch_size, dstate, device=device)
C = torch.randn(batch_size, dstate, device=device)
D = torch.randn(dim, device=device)
if has_z:
z = torch.randn_like(x)
else:
z = None
state_ref = state.detach().clone()
out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
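
# test_selective_state_update_with_heads repeats the check in the multi-head
# layout (batch, nheads, headdim, dstate), with grouped B/C and an optional
# tie_hdim mode where dt, dt_bias, A and D are broadcast across headdim.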
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
# @pytest.mark.parametrize('itype', [torch.float16])
@pytest.mark.parametrize("has_z", [False, True])
# @pytest.mark.parametrize('has_z', [True])
@pytest.mark.parametrize("tie_hdim", [False, True])
# @pytest.mark.parametrize('tie_hdim', [True])
@pytest.mark.parametrize("ngroups", [1, 2, 4])
# @pytest.mark.parametrize("ngroups", [2])
@pytest.mark.parametrize("dstate", [16, 32, 64])
# @pytest.mark.parametrize("dstate", [16])
@pytest.mark.parametrize("dim", [2048, 4096])
# @pytest.mark.parametrize("dim", [2048])
def test_selective_state_update_with_heads(dim, dstate, ngroups, has_z, tie_hdim, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 1e-1
# set seed
torch.random.manual_seed(0)
batch_size = 2
headdim = 64
nheads = dim // headdim
state = torch.randn(batch_size, nheads, headdim, dstate, dtype=itype, device=device)
x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
if not tie_hdim:
dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
D = torch.randn(nheads, headdim, device=device)
else:
dt = repeat(torch.randn(batch_size, nheads, device=device, dtype=itype), "b h -> b h p", p=headdim)
dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
A = repeat(-torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate)
D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
B = torch.randn(batch_size, ngroups, dstate, device=device)
C = torch.randn(batch_size, ngroups, dstate, device=device)
if has_z:
z = torch.randn_like(x)
else:
z = None
state_ref = state.detach().clone()
state_og = state.detach().clone()
out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
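
# The next block appears to be a further test file for the SSD (Mamba-2)
# chunked-scan ops. test_chunk_state_varlen packs variable-length sequences
# into one flat batch (via cu_seqlens and seq_idx) and checks that
# chunk_state_varlen reproduces the per-sequence final states obtained by
# running each sequence separately through _chunk_cumsum_fwd, chunk_state and
# _state_passing_fwd.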
import math
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange, repeat
from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd
from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen
from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref
from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd
from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_chunk_scan, ssd_chunk_scan_combined_ref, ssd_selective_scan
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined, mamba_split_conv1d_scan_ref
def detach_clone(*args):
return tuple([arg.detach().clone().requires_grad_() if arg is not None else None for arg in args])
@pytest.mark.parametrize('dtype', [torch.float32, torch.float16, torch.bfloat16])
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('ngroups', [1, 2, 8, "max"])
# @pytest.mark.parametrize('ngroups', [1])
@pytest.mark.parametrize('chunk_size', [64, 128])
# @pytest.mark.parametrize('chunk_size', [128])
def test_chunk_state_varlen(chunk_size, ngroups, dtype):
device = 'cuda'
rtol, atol = (1e-2, 3e-3)
# set seed
torch.random.manual_seed(chunk_size + (ngroups if ngroups != "max" else 64))
batch = 300
seqlens = torch.randint(1, 200, (batch,), device=device)
# batch = 3
# seqlens = torch.tensor([201, 56, 5], device=device)
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))
total_seqlen = seqlens.sum().item()
seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(seqlens)], dim=0).unsqueeze(0)
dim = 4096
# dim = 64
headdim = 64
# dim = 32
dstate = 32
assert dim % headdim == 0
nheads = dim // headdim
if ngroups == "max":
ngroups = nheads
assert nheads % ngroups == 0
B = torch.randn(total_seqlen, ngroups, dstate, dtype=dtype, device=device) / 5
x = torch.randn(total_seqlen, nheads, headdim, dtype=dtype, device=device)
A = -0.1 * (torch.rand(nheads, device=device))
dt = F.softplus(torch.randn(total_seqlen, nheads, device=device, dtype=torch.float32) - 4)
dA_cumsum, dt_rounded = _chunk_cumsum_fwd(dt.unsqueeze(0), A, chunk_size)
chunk_states = _chunk_state_fwd(B.unsqueeze(0), x.unsqueeze(0), dt_rounded, dA_cumsum, seq_idx=seq_idx)
chunk_states, _ = _state_passing_fwd(rearrange(chunk_states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
seq_idx=seq_idx, chunk_size=chunk_size)
chunk_states = rearrange(chunk_states, "... (p n) -> ... p n", n=dstate)
chunk_states = chunk_states.squeeze(0)
dA_cumsum = dA_cumsum.squeeze(0)
dt_rounded = dt_rounded.squeeze(0)
out = chunk_state_varlen(B, x, dt_rounded, dA_cumsum, cu_seqlens, chunk_states)
out_ref = []
for b in range(batch):
x_s = x[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0)
B_s = B[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0)
dt_s = dt[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0)
dA_cumsum_s, dt_rounded_s = _chunk_cumsum_fwd(dt_s, A, chunk_size)
states = chunk_state(B_s, x_s, dt_rounded_s, dA_cumsum_s)
_, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum_s[:, :, :, -1],
chunk_size=chunk_size)
final_states = rearrange(final_states, "... (p n) -> ... p n", n=dstate)
out_ref.append(final_states)
out_ref = torch.cat(out_ref, dim=0)
print(f"Max diff = {(out - out_ref).abs().max().item()}")
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
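
# The final block appears to be the generation tests. test_generation builds
# a small Mamba2 language model and checks that step-by-step decoding (with
# CUDA graphs and teacher-forced tokens) reproduces the logits of a single
# full-sequence forward pass.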
import torch
import torch.nn.functional as F
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.utils.generation import InferenceParams
import pytest
from einops import rearrange, repeat
def test_generation():
batch = 3
seqlen = 20
device = "cuda"
dtype = torch.float16
config = MambaConfig(
d_model=1024,
n_layer=4,
vocab_size=50277,
ssm_cfg=dict(layer="Mamba2"),
rms_norm=True,
residual_in_fp32=True,
fused_add_norm=True,
pad_vocab_size_multiple=16,
)
torch.manual_seed(2357)
model = MambaLMHeadModel(config, device=device, dtype=dtype)
x = torch.randint(0, 1000, (batch, seqlen), device=device, dtype=torch.long)
out_ref = model(x).logits
prompt_len = seqlen // 2
out = model.generate(
input_ids = x[:, :prompt_len], max_length=seqlen, output_scores=True, return_dict_in_generate=True,
cg=True, # Can turn off CUDA graph for easier debugging
# instead of sampling, we take output tokens from x, to get logits for testing
# For actual generation, don't pass in teacher_outputs
teacher_outputs=x,
)
out_scores = torch.stack(out.scores, dim=1)
print(f"Max diff: {(out_scores - out_ref[:, prompt_len - 1: -1]).abs().max()}")
assert torch.allclose(out_scores, out_ref[:, prompt_len - 1: -1], rtol=1e-3, atol=1e-2)
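
# test_generation_varlen does the same with packed variable-length prompts:
# reference 1 is a packed forward pass with seq_idx, reference 2 generates
# each sequence separately, and the varlen decode path must stay within 2x
# the gap between those two references.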
def test_generation_varlen():
seqlens = [170, 65, 100]
genlen = 20
total_seqlen = sum(seqlens)
device = "cuda"
dtype = torch.float16
config = MambaConfig(
d_model=1024,
n_layer=4,
vocab_size=50277,
ssm_cfg=dict(layer="Mamba2"),
rms_norm=True,
residual_in_fp32=True,
fused_add_norm=True,
pad_vocab_size_multiple=16,
)
torch.manual_seed(2357)
model = MambaLMHeadModel(config, device=device, dtype=dtype)
xs = [torch.randint(0, 1000, (1, seqlen), device=device, dtype=torch.long) for seqlen in seqlens]
# Reference 1: Forward pass with seq_idx
x = torch.cat(xs, dim=1)
seq_idx = torch.cat([torch.full((ids.shape[1],), i, dtype=torch.int32, device=device)
for i, ids in enumerate(xs)], dim=0).unsqueeze(0)
cu_seqlens = F.pad(torch.tensor(seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
out_ref = model(x, seq_idx=seq_idx).logits
# Only take the last @genlen logits of each sequence
out_ref = torch.cat([out_ref[:, cu_seqlens[i + 1] - genlen - 1:cu_seqlens[i + 1] - 1]
for i in range(len(seqlens))], dim=0)
# Reference 2: Generate the last @genlen tokens of each sequence in a for loop
out_loop = []
for input_ids in xs:
out = model.generate(
input_ids=input_ids[:, :-genlen], max_length=input_ids.shape[1], output_scores=True,
return_dict_in_generate=True, cg=True, teacher_outputs=input_ids,
).scores
out_loop.append(torch.stack(out, dim=1))
out_loop = torch.cat(out_loop, dim=0)
print(f"Max diff between ref1 and ref2: {(out_loop - out_ref).abs().max()}")
# Varlen generation
input_ids = torch.cat([ids[:, :-genlen] for ids in xs], dim=1)
prompt_seqlens = [seqlen - genlen for seqlen in seqlens]
cu_seqlens = F.pad(torch.tensor(prompt_seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
seq_idx = torch.cat([torch.full((seqlen,), i, dtype=torch.int32, device=device)
for i, seqlen in enumerate(prompt_seqlens)], dim=0).unsqueeze(0)
inference_params = InferenceParams(max_seqlen=2048, max_batch_size=len(seqlens))
scores, sequences = [], []
# Both seq_idx and cu_seqlens must be passed in for varlen generation
logits = model(input_ids, inference_params=inference_params, seq_idx=seq_idx, cu_seqlens=cu_seqlens).logits
logits = rearrange(logits[0, cu_seqlens[1:] - 1], "b d -> b 1 d")
scores.append(logits)
# In practice we should sample. In this case we take from the teacher_output for testing
sampled_tokens = rearrange(torch.stack([ids[0, -genlen] for ids in xs], dim=0), "b -> b 1")
sequences.append(sampled_tokens)
for i in range(1, genlen):
inference_params.seqlen_offset += 1
logits = model(sampled_tokens, inference_params=inference_params, num_last_tokens=1).logits
scores.append(logits)
# In practice we should sample. In this case we take from the teacher_output for testing
sampled_tokens = rearrange(torch.stack([ids[0, -genlen + i] for ids in xs], dim=0), "b -> b 1")
sequences.append(sampled_tokens)
out_varlen = torch.cat(scores, dim=1)
print(f"Max diff: {(out_varlen - out_ref).abs().max()}")
assert (out_varlen - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max()