"vscode:/vscode.git/clone" did not exist on "352c0c8a285414b11373e65fef095af7b07b94d8"
test_causal_conv1d.py 13 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7

import pytest
import torch
import torch.nn.functional as F
8
from einops import rearrange
9

10
from vllm.attention.backends.utils import PAD_SLOT_ID
11
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
12
13
14
    causal_conv1d_fn,
    causal_conv1d_update,
)
15
from vllm.platforms import current_platform
16
17
18
19
20


def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
21
22
    bias: torch.Tensor | None = None,
    initial_states: torch.Tensor | None = None,
23
    return_final_states: bool = False,
24
25
    final_states_out: torch.Tensor | None = None,
    activation: str | None = "silu",
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)

    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    if initial_states is None:
43
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
44
45
46
47
48
49
    else:
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
50
51
            dtype_in
        )  # (batch, dim, width - 1)
52
53
54
55
56
57
58
59
        if final_states_out is not None:
            final_states_out.copy_(final_states)
        else:
            final_states_out = final_states
    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
    return (out, None) if not return_final_states else (out, final_states_out)


60
61
62
def causal_conv1d_update_ref(
    x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
):
63
    """
64
65
    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
66
67
    weight: (dim, width)
    bias: (dim,)
68
69
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
70
        The conv_state will be updated by copying x to the
71
72
        conv_state starting at the index
        @cache_seqlens % state_len before performing the convolution.
73

74
    out: (batch, dim) or (batch, dim, seqlen)
75
76
77
78
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
79
80
81
82
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
83
    width = weight.shape[1]
84
85
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
86
    assert weight.shape == (dim, width)
87
88
    if cache_seqlens is None:
        x_new = torch.cat([conv_state, x], dim=-1).to(
89
90
            weight.dtype
        )  # (batch, dim, state_len + seqlen)
91
92
93
        conv_state.copy_(x_new[:, :, -state_len:])
    else:
        width_idx = torch.arange(
94
95
96
97
98
99
100
101
102
103
            -(width - 1), 0, dtype=torch.long, device=x.device
        ).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        width_idx = (
            torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        )
        x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
        copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(
            0
        ) + cache_seqlens.unsqueeze(1)
        copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
104
        conv_state.scatter_(2, copy_idx, x)
105
106
107
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[
        :, :, -seqlen:
    ]
108
109
    if unsqueeze:
        out = out.squeeze(-1)
110
111
112
    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)


113
114
115
def causal_conv1d_opcheck_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
    cu_seq_len: torch.Tensor | None = None,
    cache_indices: torch.Tensor | None = None,
    has_initial_state: torch.Tensor | None = None,
    conv_states: torch.Tensor | None = None,
    activation: str | None = "silu",
    pad_slot_id: int = PAD_SLOT_ID,
):
    """Validate and normalize arguments for a causal-conv1d op check.

    This is a plain helper, not a test: its name does not start with
    ``test_``, so it carries no ``pytest.mark.parametrize`` decorators.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    activation: either None or "silu" or "swish"

    ``cu_seq_len``, ``cache_indices``, ``has_initial_state``, ``conv_states``
    and ``pad_slot_id`` are accepted but not read by the body below.

    Returns:
        None. The body only validates ``activation`` and makes local
        contiguous copies of ``x`` and ``bias``.

    Raises:
        NotImplementedError: If ``activation`` is not ``None``, ``"silu"`` or
            ``"swish"``.
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    # NOTE(review): the contiguous copies below are rebound locally and never
    # consumed here - presumably an op-check call used to follow; confirm
    # whether this helper is still needed.
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

144
145
146
147

@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype):
    """Compare ``causal_conv1d_update`` with ``causal_conv1d_update_ref``.

    Both the returned output (within tolerance) and the in-place updated
    convolution state (exactly) must match the reference.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    current_platform.seed_everything(0)
    batch = 2
    x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
    x_ref = x.clone()
    conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    # Independent copy: both implementations update their state in place.
    conv_state_ref = conv_state.detach().clone()
    activation = None if not silu_activation else "silu"

    # Identity mapping: batch entry i uses cache line i.
    conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device)

    out = causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias,
        activation=activation,
        conv_state_indices=conv_state_indices,
    )
    out_ref = causal_conv1d_update_ref(
        x_ref, conv_state_ref, weight, bias, activation=activation
    )

    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
184
185


186
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 3])
@pytest.mark.parametrize("width", [3, 4])
@pytest.mark.parametrize("dim", [2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3])
def test_causal_conv1d_update_with_batch_gather(
    batch_size, with_padding, dim, width, seqlen, has_bias, silu_activation, itype
):
    """Check ``causal_conv1d_update`` when states are gathered by index.

    The state cache holds more lines than there are sequences; each sequence
    selects its line through ``conv_state_indices``. With ``with_padding`` the
    batch is extended by entries whose index is ``PAD_SLOT_ID``. The test
    asserts that selected lines match the reference, that unselected lines are
    left untouched, and that the outputs of the real sequences match.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2

    # set seed
    current_platform.seed_everything(0)

    padding = 5 if with_padding else 0
    padded_batch_size = batch_size + padding
    # total_entries = number of cache line
    total_entries = 10 * batch_size

    # x will be (batch, dim, seqlen) with contiguous along dim-axis
    x = torch.randn(
        padded_batch_size, seqlen, dim, device=device, dtype=itype
    ).transpose(1, 2)

    x_ref = x.clone()

    conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device
    )
    # Mask of cache lines that no sequence refers to; they must stay unchanged.
    unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
    unused_states_bool[conv_state_indices] = False
    padded_state_indices = torch.concat(
        [
            conv_state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=0,
    )

    # conv_state will be (cache_lines, dim, state_len)
    # with contiguous along dim-axis
    conv_state = torch.randn(
        total_entries, width - 1, dim, device=device, dtype=itype
    ).transpose(1, 2)

    conv_state_for_padding_test = conv_state.clone()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    # The reference operates on a dense copy of just the selected cache lines.
    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
    activation = None if not silu_activation else "silu"

    out = causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias,
        activation=activation,
        conv_state_indices=padded_state_indices,
        pad_slot_id=PAD_SLOT_ID,
    )
    out_ref = causal_conv1d_update_ref(
        x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
    )

    assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
    assert torch.equal(
        conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool]
    )
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
262
263
264
265
266
267


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("seqlen", [8, 249, 4096])
@pytest.mark.parametrize("dim", [64, 4096])
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch", [4, 10])
def test_causal_conv1d_varlen(
    batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype
):
    """Check ``causal_conv1d_fn`` on variable-length sequences packed together.

    A single stream of ``seqlen`` tokens is cut at random positions into
    sequences. The kernel processes the packed stream using
    ``query_start_loc`` and per-sequence cache indices, while the reference
    runs ``causal_conv1d_ref`` sequence by sequence. Outputs and the final
    states of the real (non-padded) cache lines must agree within tolerance.
    """
    device = "cuda"
    torch.cuda.empty_cache()
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    current_platform.seed_everything(0)
    batch_size = batch
    padding = 3 if with_padding else 0
    padded_batch_size = batch_size + padding
    nsplits = padded_batch_size - 1

    # Random, sorted end-of-sequence positions; consecutive differences give
    # the individual sequence lengths.
    # NOTE(review): randperm(seqlen - 1) yields at most seqlen - 1 cut points,
    # so for seqlen=8 with batch=10 there are fewer sequences than entries in
    # padded_state_indices - confirm this is intended.
    eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
    seqlens = torch.diff(
        torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
    ).tolist()
    assert sum(seqlens) == seqlen
    assert all(s > 0 for s in seqlens)

    total_entries = batch_size * 10
    # Cumulative start offsets with a leading zero, as int32.
    cumsum = torch.cumsum(torch.tensor(seqlens), dim=0).to(torch.int32)
    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0)
    # Slice the channels out of a wider tensor so x is a non-contiguous view.
    x = rearrange(
        torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype),
        "b s d -> b d s",
    )[:, 4096 : 4096 + dim, :]

    weight = torch.randn(dim, width, device=device, dtype=itype)

    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    activation = None if not silu_activation else "silu"
    final_states = torch.randn(
        total_entries, width - 1, dim, device=x.device, dtype=x.dtype
    ).transpose(1, 2)
    final_states_ref = final_states.clone()
    has_initial_states = torch.randint(
        0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
    )
    state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
        :batch_size
    ]
    padded_state_indices = torch.concat(
        [
            state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=-1,
    )
    out = causal_conv1d_fn(
        x.squeeze(0),
        weight,
        bias=bias,
        conv_states=final_states,
        query_start_loc=cumsum.to(device),
        cache_indices=padded_state_indices,
        has_initial_state=has_initial_states,
        activation=activation,
        pad_slot_id=PAD_SLOT_ID,
    )

    # Reference: run each real sequence separately; padded slots are skipped.
    out_ref_b = []
    x_splits = torch.split(x_ref, seqlens, dim=-1)  # each (1, dim, len_i)
    for i, x_s in enumerate(x_splits):
        if padded_state_indices[i] == PAD_SLOT_ID:
            continue
        out_ref_b.append(
            causal_conv1d_ref(
                x_s,
                weight_ref,
                bias_ref,
                activation=activation,
                return_final_states=True,
                # Views into final_states_ref, so the reference cache is
                # updated in place.
                final_states_out=final_states_ref[padded_state_indices[i]].unsqueeze(0),
                initial_states=final_states_ref[padded_state_indices[i]].unsqueeze(0)
                if has_initial_states[i]
                else None,
            )
        )
    # Concatenate the per-sequence outputs back along the token axis.
    out_ref_tensor = torch.cat([t[0] for t in out_ref_b], dim=2)

    assert torch.allclose(
        final_states[state_indices],
        final_states_ref[state_indices],
        rtol=rtol,
        atol=atol,
    )
    # Only tokens of non-padded sequences have a reference value.
    unpadded_out = out[:, : out_ref_tensor.shape[-1]]
    assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)