# test_causal_conv1d.py
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
from typing import Optional

import pytest
import torch
import torch.nn.functional as F

9
10
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401
11
from vllm.attention.backends.utils import PAD_SLOT_ID
12
13
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
14
from vllm.platforms import current_platform
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61


def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """Reference causal depthwise conv1d built on ``F.conv1d``.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1); when given it is prepended to x
        instead of the implicit left zero padding
    final_states_out: (batch, dim, width - 1); when given together with
        return_final_states it is filled in place and returned
    activation: None, "silu" or "swish" (any non-None value applies SiLU)

    Returns (out, final_states). out is (batch, dim, seqlen) in x's dtype.
    final_states is None unless return_final_states is True; in that case it
    holds the last width - 1 columns of the (history-prefixed) input,
    left zero-padded when the input is shorter than that.
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    in_dtype = x.dtype
    seq_len = x.shape[-1]
    channels, kernel_size = weight.shape
    kernel = weight.unsqueeze(1)  # (dim, 1, width): one filter per channel
    signal = x.to(weight.dtype)

    if initial_states is None:
        # Pad both sides by width - 1; trimming to seq_len below drops the
        # right-hand overhang, which keeps the convolution causal.
        pad = kernel_size - 1
    else:
        signal = torch.cat([initial_states, signal], dim=-1)
        pad = 0
    y = F.conv1d(signal, kernel, bias, padding=pad, groups=channels)
    y = y[..., :seq_len]
    if activation is not None:
        y = F.silu(y)
    y = y.to(dtype=in_dtype)

    if not return_final_states:
        return y, None

    # A negative left pad crops, so this always yields width - 1 columns.
    tail = F.pad(signal, (kernel_size - 1 - signal.shape[-1], 0)).to(in_dtype)
    if final_states_out is None:
        return y, tail
    final_states_out.copy_(tail)
    return y, final_states_out


def causal_conv1d_update_ref(x,
                             conv_state,
                             weight,
                             bias=None,
                             activation=None,
                             cache_seqlens=None):
    """Reference for an incremental (decode-step) causal conv1d update.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1.
        Mutated in place so that it holds the newest inputs afterwards.
    weight: (dim, width)
    bias: (dim,)
    activation: None, "silu" or "swish" (any non-None value applies SiLU)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the
        conv_state starting at the index
        @cache_seqlens % state_len before performing the convolution.

    out: (batch, dim) or (batch, dim, seqlen), matching the rank of x,
        in x's dtype.
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    # Accept a single token per sequence as (batch, dim).
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    if cache_seqlens is None:
        # Linear buffer: history + new inputs, then keep the newest
        # state_len columns as the new state.
        x_new = torch.cat([conv_state, x], dim=-1).to(
            weight.dtype)  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])
    else:
        # Circular buffer: read the width - 1 entries that precede the
        # write position (modulo state_len) as history ...
        width_idx = torch.arange(
            -(width - 1), 0, dtype=torch.long,
            device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(
            -1, dim, -1)
        x_new = torch.cat([conv_state.gather(2, width_idx), x],
                          dim=-1).to(weight.dtype)
        # ... then write the new inputs starting at the write position.
        copy_idx = torch.arange(
            seqlen, dtype=torch.long,
            device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        copy_idx = torch.remainder(copy_idx,
                                   state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, copy_idx, x)
    # Valid (unpadded) depthwise conv; only the last seqlen outputs
    # correspond to the new inputs.
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0,
                   groups=dim)[:, :, -seqlen:]
    if unsqueeze:
        out = out.squeeze(-1)
    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)


def causal_conv1d_opcheck_fn(x: torch.Tensor,
                             weight: torch.Tensor,
                             bias: Optional[torch.Tensor] = None,
                             cu_seq_len: Optional[torch.Tensor] = None,
                             cache_indices: Optional[torch.Tensor] = None,
                             has_initial_state: Optional[torch.Tensor] = None,
                             conv_states: Optional[torch.Tensor] = None,
                             activation: Optional[str] = "silu",
                             pad_slot_id: int = PAD_SLOT_ID):
    """Validate the ``_C.causal_conv1d_fwd`` custom op with ``opcheck``.

    Arguments are forwarded positionally to the op after light
    normalization: ``x`` is made contiguous along its last axis, ``bias`` is
    made contiguous, and ``activation`` is reduced to a SiLU on/off flag.

    x: (batch, dim, seqlen); the varlen test passes (dim, total_seqlen)
        together with ``cu_seq_len``
    weight: (dim, width)
    bias: (dim,)
    cu_seq_len: int32 cumulative sequence offsets starting at 0 (varlen)
    cache_indices: per-sequence slot into ``conv_states``; entries equal to
        ``pad_slot_id`` mark padded sequences
    has_initial_state: per-sequence bool flags
    conv_states: state cache, e.g. (num_slots, dim, width - 1) as built by
        the tests in this file
    activation: either None or "silu" or "swish"
    pad_slot_id: sentinel slot id used for padded sequences
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

    opcheck(torch.ops._C.causal_conv1d_fwd,
            (x, weight, bias, conv_states, cu_seq_len, cache_indices,
             has_initial_state, activation in ["silu", "swish"], pad_slot_id))


@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("has_initial_state", [True, False])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
    'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
@pytest.mark.parametrize('dim', [64])
@pytest.mark.parametrize('batch', [1])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
                       has_initial_state, itype):
    """Compare ``causal_conv1d_fn`` with ``causal_conv1d_ref`` on a batch.

    With ``has_initial_state``, the kernel is expected to overwrite
    ``initial_states`` in place with the final convolution state; that is
    checked against the reference's final states. The custom op is then
    validated with ``opcheck``.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    current_platform.seed_everything(0)

    x = torch.randn(batch, dim, seqlen, device=device,
                    dtype=itype).contiguous()
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    if has_initial_state:
        initial_states = torch.randn(batch,
                                     dim,
                                     width - 1,
                                     device=device,
                                     dtype=itype)
        has_initial_state_tensor = torch.ones(batch,
                                              dtype=torch.bool,
                                              device=x.device)
    else:
        initial_states = None
        has_initial_state_tensor = None

    # Snapshot inputs before the kernel runs (it may mutate initial_states).
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    initial_states_ref = initial_states.clone(
    ) if initial_states is not None else None
    activation = None if not silu_activation else "silu"

    out = causal_conv1d_fn(x,
                           weight,
                           bias,
                           activation=activation,
                           conv_states=initial_states,
                           has_initial_state=has_initial_state_tensor)
    out_ref, final_states_ref = causal_conv1d_ref(
        x_ref,
        weight_ref,
        bias_ref,
        initial_states=initial_states_ref,
        return_final_states=True,
        activation=activation)
    if has_initial_state:
        assert initial_states is not None and final_states_ref is not None
        assert torch.allclose(initial_states,
                              final_states_ref,
                              rtol=rtol,
                              atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

    causal_conv1d_opcheck_fn(x,
                             weight,
                             bias,
                             activation=activation,
                             conv_states=initial_states,
                             has_initial_state=has_initial_state_tensor)


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
                              itype):
    """Compare ``causal_conv1d_update`` with the reference on a plain batch.

    Both the output and the in-place updated ``conv_state`` must match the
    reference; the custom op is then validated with ``opcheck``.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    current_platform.seed_everything(0)
    batch = 2
    x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
    x_ref = x.clone()
    conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state.detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(x,
                               conv_state,
                               weight,
                               bias,
                               activation=activation)
    out_ref = causal_conv1d_update_ref(x_ref,
                                       conv_state_ref,
                                       weight,
                                       bias,
                                       activation=activation)

    # State updates are pure copies, so they must match exactly.
    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

    opcheck(torch.ops._C.causal_conv1d_update,
            (x, conv_state, weight, bias, activation
             in ["silu", "swish"], None, None, PAD_SLOT_ID))

@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 4, 5])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
                                                seqlen, has_bias,
                                                silu_activation, itype):
    """Check ``causal_conv1d_update`` when states are addressed by index.

    Only ``batch_size`` of ``total_entries`` cache slots are selected through
    ``conv_state_indices``; the trailing ``padding`` rows of ``x`` carry
    PAD_SLOT_ID. The selected slots and outputs must match the reference,
    and every other slot must be left untouched.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2

    # set seed
    current_platform.seed_everything(0)

    batch_size = 3
    padding = 5 if with_padding else 0
    padded_batch_size = batch_size + padding
    total_entries = 10 * batch_size

    # Each sequence contributes `seqlen` new tokens, so the seqlen
    # parametrization actually exercises multi-token updates.
    x = torch.randn(padded_batch_size,
                    dim,
                    seqlen,
                    device=device,
                    dtype=itype)
    x_ref = x.clone()

    conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device)
    unused_states_bool = torch.ones(total_entries,
                                    dtype=torch.bool,
                                    device=device)
    unused_states_bool[conv_state_indices] = False
    padded_state_indices = torch.concat([
        conv_state_indices,
        torch.as_tensor(
            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
    ],
                                        dim=0)
    conv_state = torch.randn(total_entries,
                             dim,
                             width - 1,
                             device=device,
                             dtype=itype)
    # Pristine copy used to verify untouched slots stay untouched.
    conv_state_for_padding_test = conv_state.clone()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(x,
                               conv_state,
                               weight,
                               bias,
                               activation=activation,
                               conv_state_indices=padded_state_indices,
                               pad_slot_id=PAD_SLOT_ID)
    out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
                                       conv_state_ref,
                                       weight,
                                       bias,
                                       activation=activation)

    assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
    assert torch.equal(conv_state[unused_states_bool],
                       conv_state_for_padding_test[unused_states_bool])

    opcheck(torch.ops._C.causal_conv1d_update,
            (x, conv_state, weight, bias, activation
             in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID))


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
    'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
@pytest.mark.parametrize('dim', [64, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize('with_padding', [True, False])
def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
                              silu_activation, itype):
    """Check ``causal_conv1d_fn`` on packed variable-length sequences.

    One packed sequence of length ``seqlen`` is split at random cut points
    into ``padded_batch_size`` non-empty pieces described by ``cumsum``.
    Pieces whose state index is PAD_SLOT_ID are skipped by the reference.
    Outputs and the per-slot states in ``final_states`` are compared with
    running ``causal_conv1d_ref`` on each piece independently.
    """
    device = "cuda"
    torch.cuda.empty_cache()
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    current_platform.seed_everything(0)

    batch_size = 4
    if seqlen < 10:
        batch_size = 1
    padding = 3 if with_padding else 0
    padded_batch_size = batch_size + padding
    nsplits = padded_batch_size - 1

    # Distinct sorted cut points guarantee every piece is non-empty and the
    # piece lengths sum to seqlen.
    eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
    seqlens = torch.diff(
        torch.cat([torch.tensor([-1]), eos_pos,
                   torch.tensor([seqlen - 1])])).tolist()
    assert sum(seqlens) == seqlen
    assert all(s > 0 for s in seqlens)

    total_entries = batch_size * 10
    cumsum = torch.cumsum(torch.tensor(seqlens), dim=0).to(torch.int32)
    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
                          dim=0)
    # Slice out of a wider allocation so x is a view with a non-zero
    # storage offset.
    x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
                    dtype=itype)[:, 4096:4096 + dim, :]
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    activation = None if not silu_activation else "silu"
    final_states = torch.randn(total_entries,
                               dim,
                               width - 1,
                               device=x.device,
                               dtype=x.dtype)
    final_states_ref = final_states.clone()
    has_initial_states = torch.randint(0,
                                       2, (cumsum.shape[0] - 1, ),
                                       dtype=torch.bool,
                                       device=x.device)
    state_indices = torch.randperm(total_entries,
                                   dtype=torch.int32,
                                   device=x.device)[:batch_size]
    padded_state_indices = torch.concat([
        state_indices,
        torch.as_tensor(
            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
    ],
                                        dim=-1)

    out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
                           padded_state_indices, has_initial_states,
                           final_states, activation, PAD_SLOT_ID)

    # Reference: run every non-padded piece on its own, reading (optionally)
    # and writing its own slot of final_states_ref.
    out_ref_pieces = []
    for i, x_s in enumerate(torch.split(x_ref, seqlens, dim=-1)):
        slot = padded_state_indices[i]
        if slot == PAD_SLOT_ID:
            continue
        state_slot = final_states_ref[slot].unsqueeze(0)
        piece_out, _ = causal_conv1d_ref(
            x_s,
            weight_ref,
            bias_ref,
            activation=activation,
            return_final_states=True,
            final_states_out=state_slot,
            initial_states=state_slot if has_initial_states[i] else None)
        out_ref_pieces.append(piece_out)
    out_ref_tensor = torch.cat(out_ref_pieces, dim=2)

    # Padded pieces sit at the end, so the reference covers a prefix of out.
    unpadded_out = out[:, :out_ref_tensor.shape[-1]]
    assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
    assert torch.allclose(final_states[state_indices],
                          final_states_ref[state_indices],
                          rtol=rtol,
                          atol=atol)

    causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
                             padded_state_indices, has_initial_states,
                             final_states, activation)