# test_causal_conv1d.py
from typing import Optional

import pytest
import torch
import torch.nn.functional as F

7
8
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401
9
from vllm.attention.backends.utils import PAD_SLOT_ID
10
11
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
12
from vllm.utils import seed_everything
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59


def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """
    Pure-PyTorch reference for the causal depthwise 1D convolution.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1); when given it is prepended to x
        in place of the implicit zero padding.
    return_final_states: when True, the last (width - 1) input columns are
        returned as the second element of the result.
    final_states_out: (batch, dim, width - 1); optional buffer that receives
        the final states in place when return_final_states is True.
    activation: None, "silu" or "swish" (both strings apply F.silu).

    Returns (out, final_states) with out: (batch, dim, seqlen) in the dtype
    of the incoming x; final_states is None unless return_final_states.
    Raises NotImplementedError for any other activation value.
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")

    input_dtype = x.dtype
    # The convolution itself runs in the weight's dtype.
    x = x.to(weight.dtype)
    num_steps = x.shape[-1]
    num_channels, kernel_width = weight.shape
    depthwise_kernel = weight.unsqueeze(1)

    if initial_states is not None:
        # The supplied history replaces the zero padding on the left.
        x = torch.cat([initial_states, x], dim=-1)
        conv_out = F.conv1d(x,
                            depthwise_kernel,
                            bias,
                            padding=0,
                            groups=num_channels)
    else:
        conv_out = F.conv1d(x,
                            depthwise_kernel,
                            bias,
                            padding=kernel_width - 1,
                            groups=num_channels)
    # Keep only the causal part of the output.
    conv_out = conv_out[..., :num_steps]

    if return_final_states:
        # A negative left pad crops from the left and a positive one
        # zero-fills, so this always yields exactly (width - 1) columns.
        left_pad = kernel_width - 1 - x.shape[-1]
        last_columns = F.pad(x, (left_pad, 0)).to(
            input_dtype)  # (batch, dim, width - 1)
        if final_states_out is None:
            final_states_out = last_columns
        else:
            final_states_out.copy_(last_columns)

    if activation is not None:
        conv_out = F.silu(conv_out)
    conv_out = conv_out.to(dtype=input_dtype)

    if return_final_states:
        return (conv_out, final_states_out)
    return (conv_out, None)


60
61
62
63
64
65
def causal_conv1d_update_ref(x,
                             conv_state,
                             weight,
                             bias=None,
                             activation=None,
                             cache_seqlens=None):
    """
    Pure-PyTorch reference for the incremental causal conv1d update.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the
        conv_state starting at the index
        @cache_seqlens % state_len before performing the convolution.

    out: (batch, dim) or (batch, dim, seqlen)

    conv_state is modified in place. The result is returned in the dtype of
    the incoming x. Raises NotImplementedError when activation is not one of
    None, "silu" or "swish".
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    # A 2-D input is handled as a single step and squeezed back at the end.
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    if cache_seqlens is None:
        # Linear state: append x and keep the trailing state_len columns.
        x_new = torch.cat([conv_state, x], dim=-1).to(
            weight.dtype)  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])
    else:
        # Circular state: read the (width - 1) columns that precede
        # cache_seqlens (modulo state_len) ...
        width_idx = torch.arange(
            -(width - 1), 0, dtype=torch.long,
            device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(
            -1, dim, -1)
        x_new = torch.cat([conv_state.gather(2, width_idx), x],
                          dim=-1).to(weight.dtype)
        # ... then write x into the buffer starting at cache_seqlens.
        copy_idx = torch.arange(
            seqlen, dtype=torch.long,
            device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        copy_idx = torch.remainder(copy_idx,
                                   state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, copy_idx, x)
    # Only the last seqlen outputs correspond to the new inputs.
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0,
                   groups=dim)[:, :, -seqlen:]
    if unsqueeze:
        out = out.squeeze(-1)
    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)


115
116
117
def causal_conv1d_opcheck_fn(x: torch.Tensor,
                             weight: torch.Tensor,
                             bias: Optional[torch.Tensor] = None,
                             cu_seq_len: Optional[torch.Tensor] = None,
                             cache_indices: Optional[torch.Tensor] = None,
                             has_initial_state: Optional[torch.Tensor] = None,
                             conv_states: Optional[torch.Tensor] = None,
                             activation: Optional[str] = "silu",
                             pad_slot_id: int = PAD_SLOT_ID):
    """
    Run `opcheck` on `torch.ops._C.causal_conv1d_fwd` for the given inputs.

    This is a helper called from the tests, not a test itself, so it carries
    no pytest parametrization.

    x: input tensor; made contiguous when its last-dim stride is not 1
    weight: (dim, width)
    bias: (dim,) or None; made contiguous when given
    cu_seq_len, cache_indices, has_initial_state, conv_states, pad_slot_id:
        forwarded unchanged to the custom op
        (NOTE(review): the tests in this file pass cumulative sequence
        lengths, per-sequence state slot indices, a bool mask and the state
        buffer respectively -- confirm against the op's schema)
    activation: either None or "silu" or "swish"; forwarded to the op as a
        bool that is True for "silu"/"swish"

    Raises NotImplementedError for any other activation value.
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

    opcheck(torch.ops._C.causal_conv1d_fwd,
            (x, weight, bias, conv_states, cu_seq_len, cache_indices,
             has_initial_state, activation in ["silu", "swish"], pad_slot_id))
147
148


149
150
151
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
    'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
@pytest.mark.parametrize('dim', [64])
@pytest.mark.parametrize('batch', [1])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
                       itype):
    """
    Compare `causal_conv1d_fn` with `causal_conv1d_ref` for a batch that
    starts from random initial states.

    Both the output and the buffer passed as `conv_states` (checked against
    the reference final states after the call) must match within tolerance;
    the custom op is then run through `opcheck`.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    seed_everything(0)
    x = torch.randn(batch, dim, seqlen, device=device,
                    dtype=itype).contiguous()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    initial_states = torch.randn(batch,
                                 dim,
                                 width - 1,
                                 device=device,
                                 dtype=itype)
    # The reference works on clones so that it never sees memory the
    # function under test has written to.
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    initial_states_ref = initial_states.clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_fn(x,
                           weight,
                           bias,
                           activation=activation,
                           conv_states=initial_states,
                           has_initial_state=torch.ones(batch,
                                                        dtype=torch.bool,
                                                        device=x.device))
    out_ref, final_states_ref = causal_conv1d_ref(
        x_ref,
        weight_ref,
        bias_ref,
        initial_states=initial_states_ref,
        return_final_states=True,
        activation=activation)
    assert initial_states is not None and final_states_ref is not None
    # After the call, the conv_states buffer must hold the final states.
    assert torch.allclose(initial_states,
                          final_states_ref,
                          rtol=rtol,
                          atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

    causal_conv1d_opcheck_fn(x,
                             weight,
                             bias,
                             activation=activation,
                             conv_states=initial_states,
                             has_initial_state=torch.ones(batch,
                                                          dtype=torch.bool,
                                                          device=x.device))
211
212
213
214
215


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
                              itype):
    """
    Compare `causal_conv1d_update` with `causal_conv1d_update_ref`.

    The updated state must be bit-identical to the reference state and the
    output must match within tolerance; the custom op is then run through
    `opcheck`.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    seed_everything(0)
    batch = 2
    x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
    x_ref = x.clone()
    conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    # The reference updates its own copy of the state in place.
    conv_state_ref = conv_state.detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(x,
                               conv_state,
                               weight,
                               bias,
                               activation=activation)
    out_ref = causal_conv1d_update_ref(x_ref,
                                       conv_state_ref,
                                       weight,
                                       bias,
                                       activation=activation)

    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

    opcheck(torch.ops._C.causal_conv1d_update,
            (x, conv_state, weight, bias, activation
             in ["silu", "swish"], None, None, PAD_SLOT_ID))
253

254
255
256
257
258
259
260
261

@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 4, 5])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
                                                seqlen, has_bias,
                                                silu_activation, itype):
    """
    Compare `causal_conv1d_update` with `causal_conv1d_update_ref` when the
    per-sequence states are selected from a larger state pool through
    `conv_state_indices`, optionally with trailing PAD_SLOT_ID entries.

    Checks that the selected states and the outputs of the non-padded
    sequences match the reference, and that every state slot not selected
    by an index is left untouched.

    NOTE(review): the `seqlen` parameter is not used in the body -- `x` is
    always created with a trailing length of 1 -- so the three `seqlen`
    values currently run the same scenario. Confirm whether the kernel is
    meant to be exercised with longer inputs here.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2

    # set seed
    seed_everything(0)

    batch_size = 3
    padding = 5 if with_padding else 0
    padded_batch_size = batch_size + padding
    total_entries = 10 * batch_size

    x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype)
    x_ref = x.clone()

    # Distinct state slots for the real (non-padded) sequences.
    conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device)
    # Mask of the slots that no real sequence refers to.
    unused_states_bool = torch.ones(total_entries,
                                    dtype=torch.bool,
                                    device=device)
    unused_states_bool[conv_state_indices] = False
    padded_state_indices = torch.concat([
        conv_state_indices,
        torch.as_tensor(
            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
    ],
                                        dim=0)
    conv_state = torch.randn(total_entries,
                             dim,
                             width - 1,
                             device=device,
                             dtype=itype)
    # Snapshot used to verify that unused slots are not modified.
    conv_state_for_padding_test = conv_state.clone()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(x,
                               conv_state,
                               weight,
                               bias,
                               activation=activation,
                               conv_state_indices=padded_state_indices,
                               pad_slot_id=PAD_SLOT_ID)
    out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
                                       conv_state_ref,
                                       weight,
                                       bias,
                                       activation=activation)

    assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
    assert torch.equal(conv_state[unused_states_bool],
                       conv_state_for_padding_test[unused_states_bool])

    opcheck(torch.ops._C.causal_conv1d_update,
            (x, conv_state, weight, bias, activation
             in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID))
327
328
329
330
331
332


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
    'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
@pytest.mark.parametrize('dim', [64, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize('with_padding', [True, False])
def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
                              silu_activation, itype):
    """
    Compare `causal_conv1d_fn` on a packed variable-length batch with
    `causal_conv1d_ref` applied to each sequence separately.

    A total of `seqlen` tokens is split at random positions into
    `batch_size + padding` sequences; the trailing `padding` sequences get
    PAD_SLOT_ID as state index and are skipped by the reference. Outputs of
    the non-padded prefix and the whole state pool must match within
    tolerance; the custom op is then run through `opcheck`.
    """
    device = "cuda"
    torch.cuda.empty_cache()
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    seed_everything(0)
    seqlens = []
    batch_size = 4
    if seqlen < 10:
        batch_size = 1
    padding = 3 if with_padding else 0
    padded_batch_size = batch_size + padding
    nsplits = padded_batch_size - 1

    # Random end-of-sequence positions; consecutive differences give the
    # individual sequence lengths, which sum to seqlen.
    eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
    seqlens.append(
        torch.diff(
            torch.cat(
                [torch.tensor([-1]), eos_pos,
                 torch.tensor([seqlen - 1])])).tolist())
    assert sum(seqlens[-1]) == seqlen
    assert all(s > 0 for s in seqlens[-1])

    total_entries = batch_size * 10
    # Cumulative sequence lengths with a leading 0.
    cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
                          dim=0)
    # x is a channel slice of a larger tensor
    # (presumably to exercise a non-contiguous input -- TODO confirm).
    x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
                    dtype=itype)[:, 4096:4096 + dim, :]
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    activation = None if not silu_activation else "silu"
    final_states = torch.randn(total_entries,
                               dim,
                               width - 1,
                               device=x.device,
                               dtype=x.dtype)
    final_states_ref = final_states.clone()
    has_initial_states = torch.randint(0,
                                       2, (cumsum.shape[0] - 1, ),
                                       dtype=torch.bool,
                                       device=x.device)
    state_indices = torch.randperm(total_entries,
                                   dtype=torch.int32,
                                   device=x.device)[:batch_size]
    padded_state_indices = torch.concat([
        state_indices,
        torch.as_tensor(
            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
    ],
                                        dim=-1)

    out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
                           padded_state_indices, has_initial_states,
                           final_states, activation, PAD_SLOT_ID)
    out_ref = []
    out_ref_b = []

    splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
    for i in range(len(seqlens[0])):
        x_s = [v[i].unsqueeze(0) for v in splits][0]
        # Padded sequences have no reference output.
        if padded_state_indices[i] == PAD_SLOT_ID:
            continue
        out_ref_b.append(
            causal_conv1d_ref(
                x_s,
                weight_ref,
                bias_ref,
                activation=activation,
                return_final_states=True,
                final_states_out=final_states_ref[
                    padded_state_indices[i]].unsqueeze(0),
                initial_states=final_states_ref[padded_state_indices[i]].
                unsqueeze(0) if has_initial_states[i] else None))
    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
    out_ref_tensor = torch.cat(out_ref, dim=0)

    # Only the leading, non-padded part of the output is compared.
    unpadded_out = out[:, :out_ref_tensor.shape[-1]]
    assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
    assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)

    causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
                             padded_state_indices, has_initial_states,
                             final_states, activation)