"vllm/vscode:/vscode.git/clone" did not exist on "5c79b0d6484d7d4c5fe007c3c7ad04c72d3bc59e"
test_causal_conv1d.py 17.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
from typing import Optional

import pytest
import torch
import torch.nn.functional as F

10
11
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401
12
from vllm.attention.backends.utils import PAD_SLOT_ID
13
14
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
15
from vllm.platforms import current_platform
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62


def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """
    Pure-PyTorch reference for the causal depthwise 1D convolution.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1), prepended to x when given
    final_states_out: (batch, dim, width - 1), filled in place when given
        and return_final_states is True

    Returns a tuple (out, final_states):
    out: (batch, dim, seqlen), cast back to the dtype of the incoming x
    final_states: None unless return_final_states is True

    Raises NotImplementedError for an activation other than None,
    "silu" or "swish".
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")

    orig_dtype = x.dtype
    dim, width = weight.shape
    # One filter per channel: (dim, 1, width) together with groups=dim.
    kernel = weight.unsqueeze(1)
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]

    if initial_states is not None:
        # The carried-over state supplies the left context, so no padding.
        x = torch.cat([initial_states, x], dim=-1)
        pad = 0
    else:
        pad = width - 1
    conv_out = F.conv1d(x, kernel, bias, padding=pad, groups=dim)
    # Only the first seqlen positions are causal outputs.
    conv_out = conv_out[..., :seqlen]
    if activation is not None:
        conv_out = F.silu(conv_out)
    conv_out = conv_out.to(dtype=orig_dtype)

    if not return_final_states:
        return conv_out, None

    # A negative left pad crops, a positive one zero-fills: either way the
    # result holds exactly the last (width - 1) input columns.
    tail = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
        orig_dtype)  # (batch, dim, width - 1)
    if final_states_out is None:
        final_states_out = tail
    else:
        final_states_out.copy_(tail)
    return conv_out, final_states_out


63
64
65
66
67
68
def causal_conv1d_update_ref(x,
                             conv_state,
                             weight,
                             bias=None,
                             activation=None,
                             cache_seqlens=None):
    """
    Pure-PyTorch reference for the single-step causal conv1d update.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the
        conv_state starting at the index
        @cache_seqlens % state_len before performing the convolution.

    out: (batch, dim) or (batch, dim, seqlen)

    conv_state is modified in place. Raises NotImplementedError for an
    activation other than None, "silu" or "swish".
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    # Work on 3D input internally; remember to drop the extra axis at the end.
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    if cache_seqlens is None:
        # Sliding window: append x and keep the newest state_len columns.
        x_new = torch.cat([conv_state, x], dim=-1).to(
            weight.dtype)  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])
    else:
        # Circular buffer: read the (width - 1) columns that precede
        # position cache_seqlens, wrapping around state_len.
        width_idx = torch.arange(
            -(width - 1), 0, dtype=torch.long,
            device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(
            -1, dim, -1)
        x_new = torch.cat([conv_state.gather(2, width_idx), x],
                          dim=-1).to(weight.dtype)
        # Write x at positions cache_seqlens .. cache_seqlens + seqlen - 1,
        # again modulo state_len.
        copy_idx = torch.arange(
            seqlen, dtype=torch.long,
            device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        copy_idx = torch.remainder(copy_idx,
                                   state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, copy_idx, x)
    # Depthwise convolution (groups=dim); keep only the outputs for x.
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0,
                   groups=dim)[:, :, -seqlen:]
    if unsqueeze:
        out = out.squeeze(-1)
    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)


118
119
120
def causal_conv1d_opcheck_fn(x: torch.Tensor,
                             weight: torch.Tensor,
                             bias: Optional[torch.Tensor] = None,
                             cu_seq_len: Optional[torch.Tensor] = None,
                             cache_indices: Optional[torch.Tensor] = None,
                             has_initial_state: Optional[torch.Tensor] = None,
                             conv_states: Optional[torch.Tensor] = None,
                             activation: Optional[str] = "silu",
                             pad_slot_id: int = PAD_SLOT_ID):
    """
    Run `opcheck` on the `torch.ops._C.causal_conv1d_fwd` custom op.

    This is a plain helper called by the tests; it is not collected by
    pytest itself, so it carries no parametrize marks.

    x: input tensor; made contiguous when its last-dim stride is not 1
    weight: convolution weight, forwarded unchanged
    bias: optional bias; made contiguous when given
    cu_seq_len, cache_indices, has_initial_state, conv_states:
        forwarded unchanged to the op
    activation: either None or "silu" or "swish"; passed to the op as a
        boolean that is True for "silu"/"swish"
    pad_slot_id: forwarded unchanged to the op

    Raises NotImplementedError for any other activation.
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

    opcheck(torch.ops._C.causal_conv1d_fwd,
            (x, weight, bias, conv_states, cu_seq_len, cache_indices,
             has_initial_state, activation in ["silu", "swish"], pad_slot_id))
150
151


152
153
154
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("has_initial_state", [True, False])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
    'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
@pytest.mark.parametrize('dim', [64])
@pytest.mark.parametrize('batch', [1])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
                       has_initial_state, itype):
    """Compare causal_conv1d_fn with causal_conv1d_ref on a CUDA device.

    The output is compared in every case. When an initial state is given,
    the tensor passed as conv_states is also compared against the reference
    final state after the call. The op is finally run through opcheck.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    current_platform.seed_everything(0)
    x = torch.randn(batch, dim, seqlen, device=device,
                    dtype=itype).contiguous()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    if has_initial_state:
        initial_states = torch.randn(batch,
                                     dim,
                                     width - 1,
                                     device=device,
                                     dtype=itype)
        has_initial_state_tensor = torch.ones(batch,
                                              dtype=torch.bool,
                                              device=x.device)
    else:
        initial_states = None
        has_initial_state_tensor = None
    # Independent copies for the reference path, taken before the call
    # under test runs.
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    initial_states_ref = initial_states.clone(
    ) if initial_states is not None else None
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_fn(x,
                           weight,
                           bias,
                           activation=activation,
                           conv_states=initial_states,
                           has_initial_state=has_initial_state_tensor)
    out_ref, final_states_ref = causal_conv1d_ref(
        x_ref,
        weight_ref,
        bias_ref,
        initial_states=initial_states_ref,
        return_final_states=True,
        activation=activation)
    if has_initial_state:
        assert initial_states is not None and final_states_ref is not None
        # initial_states was handed to the op as conv_states; it is expected
        # to hold the final state once the call returns.
        assert torch.allclose(initial_states,
                              final_states_ref,
                              rtol=rtol,
                              atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

    causal_conv1d_opcheck_fn(x,
                             weight,
                             bias,
                             activation=activation,
                             conv_states=initial_states,
                             has_initial_state=has_initial_state_tensor)
219
220
221
222
223


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
                              itype):
    """Compare causal_conv1d_update with causal_conv1d_update_ref on CUDA.

    Both the returned output and the updated conv_state are checked; the
    state must match the reference exactly (torch.equal). The op is finally
    run through opcheck.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    current_platform.seed_everything(0)
    batch = 2
    x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
    x_ref = x.clone()
    conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    # Separate state copy so the reference updates its own buffer.
    conv_state_ref = conv_state.detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(x,
                               conv_state,
                               weight,
                               bias,
                               activation=activation)
    out_ref = causal_conv1d_update_ref(x_ref,
                                       conv_state_ref,
                                       weight,
                                       bias,
                                       activation=activation)

    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

    opcheck(torch.ops._C.causal_conv1d_update,
            (x, conv_state, weight, bias, activation
             in ["silu", "swish"], None, None, PAD_SLOT_ID))
261

262
263
264
265
266
267
268
269

@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 4, 5])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
                                                seqlen, has_bias,
                                                silu_activation, itype):
    """Check causal_conv1d_update when states are selected by index.

    conv_state holds total_entries slots; conv_state_indices picks
    batch_size of them. With padding enabled, extra batch rows are mapped
    to PAD_SLOT_ID. The test checks the output of the real rows, the
    selected state slots, and that every slot not selected is left
    untouched. The op is finally run through opcheck.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2

    # set seed
    current_platform.seed_everything(0)

    batch_size = 3
    padding = 5 if with_padding else 0
    padded_batch_size = batch_size + padding
    total_entries = 10 * batch_size

    # Use the parametrized seqlen: with a hard-coded length of 1 the
    # seqlen cases were all identical and multi-token updates went untested.
    x = torch.randn(padded_batch_size, dim, seqlen, device=device,
                    dtype=itype)
    x_ref = x.clone()

    conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device)
    # Mask of the state slots that no sequence in the batch refers to.
    unused_states_bool = torch.ones(total_entries,
                                    dtype=torch.bool,
                                    device=device)
    unused_states_bool[conv_state_indices] = False
    padded_state_indices = torch.concat([
        conv_state_indices,
        torch.as_tensor(
            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
    ],
                                        dim=0)
    conv_state = torch.randn(total_entries,
                             dim,
                             width - 1,
                             device=device,
                             dtype=itype)
    # Snapshot taken before the update, used for the untouched-slot check.
    conv_state_for_padding_test = conv_state.clone()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(x,
                               conv_state,
                               weight,
                               bias,
                               activation=activation,
                               conv_state_indices=padded_state_indices,
                               pad_slot_id=PAD_SLOT_ID)
    # The reference only sees the real (unpadded) rows and their states.
    out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
                                       conv_state_ref,
                                       weight,
                                       bias,
                                       activation=activation)

    assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
    assert torch.equal(conv_state[unused_states_bool],
                       conv_state_for_padding_test[unused_states_bool])

    opcheck(torch.ops._C.causal_conv1d_update,
            (x, conv_state, weight, bias, activation
             in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID))
335
336
337
338
339
340


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
    'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
@pytest.mark.parametrize('dim', [64, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize('with_padding', [True, False])
def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
                              silu_activation, itype):
    """Check causal_conv1d_fn on several sequences packed into one tensor.

    A total length of seqlen is cut at random positions into
    padded_batch_size sequences. The first batch_size of them get a state
    slot; the rest are mapped to PAD_SLOT_ID and are skipped by the
    reference loop. Output and final states are compared against
    causal_conv1d_ref applied per sequence, then the op is run through
    opcheck.
    """
    device = "cuda"
    torch.cuda.empty_cache()
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    current_platform.seed_everything(0)
    seqlens = []
    batch_size = 4
    if seqlen < 10:
        batch_size = 1
    padding = 3 if with_padding else 0
    padded_batch_size = batch_size + padding
    nsplits = padded_batch_size - 1

    # Pick nsplits distinct end positions; the differences between
    # consecutive ends (bracketed by -1 and seqlen - 1) are the lengths.
    eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
    seqlens.append(
        torch.diff(
            torch.cat(
                [torch.tensor([-1]), eos_pos,
                 torch.tensor([seqlen - 1])])).tolist())
    assert sum(seqlens[-1]) == seqlen
    assert all(s > 0 for s in seqlens[-1])

    total_entries = batch_size * 10
    # Cumulative sequence boundaries with a leading 0.
    cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
                          dim=0)
    # x is a slice along the channel axis of a larger random tensor.
    x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
                    dtype=itype)[:, 4096:4096 + dim, :]
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    activation = None if not silu_activation else "silu"
    final_states = torch.randn(total_entries,
                               dim,
                               width - 1,
                               device=x.device,
                               dtype=x.dtype)
    final_states_ref = final_states.clone()
    has_initial_states = torch.randint(0,
                                       2, (cumsum.shape[0] - 1, ),
                                       dtype=torch.bool,
                                       device=x.device)
    state_indices = torch.randperm(total_entries,
                                   dtype=torch.int32,
                                   device=x.device)[:batch_size]
    padded_state_indices = torch.concat([
        state_indices,
        torch.as_tensor(
            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
    ],
                                        dim=-1)

    out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
                           padded_state_indices, has_initial_states,
                           final_states, activation, PAD_SLOT_ID)
    out_ref = []
    out_ref_b = []

    # Iterating x_ref walks its leading axis (size 1); each entry is split
    # along the last axis into the per-sequence chunks.
    splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
    for i in range(len(seqlens[0])):
        x_s = [v[i].unsqueeze(0) for v in splits][0]
        if padded_state_indices[i] == PAD_SLOT_ID:
            continue
        out_ref_b.append(
            causal_conv1d_ref(
                x_s,
                weight_ref,
                bias_ref,
                activation=activation,
                return_final_states=True,
                final_states_out=final_states_ref[
                    padded_state_indices[i]].unsqueeze(0),
                initial_states=final_states_ref[padded_state_indices[i]].
                unsqueeze(0) if has_initial_states[i] else None))
    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
    out_ref_tensor = torch.cat(out_ref, dim=0)

    # The reference only covers the non-padded sequences, so compare just
    # that leading part of the output.
    unpadded_out = out[:, :out_ref_tensor.shape[-1]]
    assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
    assert torch.allclose(final_states[state_indices],
                          final_states_ref[state_indices],
                          rtol=rtol,
                          atol=atol)

    causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
                             padded_state_indices, has_initial_states,
                             final_states, activation)