# test_causal_conv1d.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
from typing import Optional

import pytest
import torch
import torch.nn.functional as F
9
from einops import rearrange
10

11
from vllm.attention.backends.utils import PAD_SLOT_ID
12
13
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
14
from vllm.platforms import current_platform
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61


def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """Pure-PyTorch reference for a causal depthwise 1-D convolution.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1); prepended to ``x`` when given,
        otherwise the sequence is implicitly left-padded with zeros.
    final_states_out: (batch, dim, width - 1); when given together with
        ``return_final_states``, the final states are copied into it.
    activation: None, "silu" or "swish" (both apply SiLU).

    Returns ``(out, final_states)``: ``out`` is (batch, dim, seqlen) in the
    dtype of ``x``; ``final_states`` is None unless ``return_final_states``.
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")
    in_dtype = x.dtype
    num_channels, kernel_width = weight.shape
    history = kernel_width - 1
    # (dim, 1, width): one filter per channel for the grouped convolution.
    depthwise_kernel = weight.unsqueeze(1)
    signal = x.to(weight.dtype)
    out_len = signal.shape[-1]

    if initial_states is not None:
        # Carried-in context replaces the implicit zero padding.
        signal = torch.cat([initial_states, signal], dim=-1)
        pad = 0
    else:
        pad = history
    conv = F.conv1d(signal,
                    depthwise_kernel,
                    bias,
                    padding=pad,
                    groups=num_channels)
    # Keep only the causal outputs (drop right-side padding results).
    conv = conv[..., :out_len]

    if activation is not None:
        conv = F.silu(conv)
    out = conv.to(dtype=in_dtype)

    if not return_final_states:
        return out, None

    # Keep the last `history` columns of the (possibly prefixed) signal:
    # a negative left pad crops, a positive one zero-fills short inputs.
    tail = F.pad(signal, (history - signal.shape[-1], 0)).to(in_dtype)
    if final_states_out is None:
        final_states_out = tail
    else:
        final_states_out.copy_(tail)
    return out, final_states_out


62
63
64
65
66
67
def causal_conv1d_update_ref(x,
                             conv_state,
                             weight,
                             bias=None,
                             activation=None,
                             cache_seqlens=None):
    """Reference decode-time update for a causal depthwise 1-D convolution.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1.
        Updated in place.
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If None, conv_state is replaced by the last state_len columns of
        ``cat([conv_state, x])``.
        Otherwise conv_state is treated as a circular buffer: the width - 1
        slots preceding index ``cache_seqlens % state_len`` are read as the
        convolution history, then x is written into the buffer starting at
        that index.
    activation: None, "silu" or "swish" (both apply SiLU).

    out: (batch, dim) or (batch, dim, seqlen), matching the rank and dtype
        of x.
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    unsqueeze = x.dim() == 2
    if unsqueeze:
        # Single-token decode: treat it as a length-1 sequence.
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)

    if cache_seqlens is None:
        x_new = torch.cat([conv_state, x], dim=-1).to(
            weight.dtype)  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])
    else:

        def ring_index(offsets):
            # Per-batch circular positions, expanded over channels so they
            # can drive gather/scatter along the last axis.
            pos = offsets.unsqueeze(0) + cache_seqlens.unsqueeze(1)
            return torch.remainder(pos, state_len).unsqueeze(1).expand(
                -1, dim, -1)

        history_idx = ring_index(
            torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device))
        x_new = torch.cat([conv_state.gather(2, history_idx), x],
                          dim=-1).to(weight.dtype)
        write_idx = ring_index(
            torch.arange(seqlen, dtype=torch.long, device=x.device))
        conv_state.scatter_(2, write_idx, x)
    # Unpadded depthwise conv over [history, x]; keep the last seqlen outputs.
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0,
                   groups=dim)[:, :, -seqlen:]
    if unsqueeze:
        out = out.squeeze(-1)
    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)


117
118
119
def causal_conv1d_opcheck_fn(x: torch.Tensor,
                             weight: torch.Tensor,
                             bias: Optional[torch.Tensor] = None,
                             cu_seq_len: Optional[torch.Tensor] = None,
                             cache_indices: Optional[torch.Tensor] = None,
                             has_initial_state: Optional[torch.Tensor] = None,
                             conv_states: Optional[torch.Tensor] = None,
                             activation: Optional[str] = "silu",
                             pad_slot_id: int = PAD_SLOT_ID):
    """Validate and normalize inputs for a causal_conv1d op check.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    cu_seq_len, cache_indices, has_initial_state, conv_states, pad_slot_id:
        accepted by the signature but not read by this helper.
    activation: either None or "silu" or "swish"

    Raises NotImplementedError for any other activation.

    NOTE(review): no op check is currently issued here. The normalized
    ``x`` and ``bias`` are computed and then discarded, and the function
    returns None. It is also not collected by pytest (no ``test_`` prefix).
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    # Presumably the op under test expects a unit stride along the last axis
    # and a contiguous bias. TODO: confirm when an op check is reinstated.
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

146
147
148
149

@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
                              itype):
    """Compare causal_conv1d_update with the reference on a dense batch.

    The updated conv state must match the reference exactly; outputs must
    agree within a dtype-dependent tolerance.
    """
    device = "cuda"
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    elif itype == torch.float32:
        rtol, atol = 3e-4, 1e-3
    else:
        rtol, atol = 3e-3, 5e-3
    current_platform.seed_everything(0)

    batch = 2
    tensor_kwargs = {"device": device, "dtype": itype}
    x = torch.randn(batch, dim, seqlen, **tensor_kwargs)
    conv_state = torch.randn(batch, dim, width - 1, **tensor_kwargs)
    weight = torch.randn(dim, width, **tensor_kwargs)
    bias = torch.randn(dim, **tensor_kwargs) if has_bias else None

    # Independent copies for the reference path, since conv_state is
    # expected to be updated in place by both implementations.
    x_ref = x.clone()
    conv_state_ref = conv_state.detach().clone()
    activation = "silu" if silu_activation else None

    out = causal_conv1d_update(x,
                               conv_state,
                               weight,
                               bias,
                               activation=activation)
    out_ref = causal_conv1d_update_ref(x_ref,
                                       conv_state_ref,
                                       weight,
                                       bias,
                                       activation=activation)

    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
183
184
185
186
187
188


@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 3])
@pytest.mark.parametrize("width", [3, 4])
@pytest.mark.parametrize("dim", [2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3])
def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
                                                width, seqlen, has_bias,
                                                silu_activation, itype):
    """Check causal_conv1d_update when conv states live in a larger cache.

    Each real row reads and writes the cache line named by its entry in
    ``conv_state_indices``. When ``with_padding`` is set, extra trailing rows
    carry PAD_SLOT_ID. Cache lines not referenced by a real row must be left
    untouched.
    """
    device = "cuda"
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    elif itype == torch.float32:
        rtol, atol = 3e-4, 1e-3
    else:
        rtol, atol = 3e-3, 5e-3
    current_platform.seed_everything(0)

    num_pad = 5 if with_padding else 0
    num_rows = batch_size + num_pad
    # Number of cache lines; only batch_size of them are referenced.
    num_cache_lines = 10 * batch_size

    # (rows, dim, seqlen) view whose memory is contiguous along dim.
    x = torch.randn(num_rows, seqlen, dim, device=device,
                    dtype=itype).transpose(1, 2)
    x_ref = x.clone()

    live_lines = torch.randperm(num_cache_lines)[:batch_size].to(
        dtype=torch.int32, device=device)
    untouched_mask = torch.ones(num_cache_lines,
                                dtype=torch.bool,
                                device=device)
    untouched_mask[live_lines] = False
    pad_ids = torch.as_tensor([PAD_SLOT_ID] * num_pad,
                              dtype=torch.int32,
                              device=device)
    row_to_line = torch.concat([live_lines, pad_ids], dim=0)

    # Cache of shape (lines, dim, width - 1), contiguous along dim.
    conv_state = torch.randn(num_cache_lines,
                             width - 1,
                             dim,
                             device=device,
                             dtype=itype).transpose(1, 2)
    conv_state_before = conv_state.clone()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state[live_lines, :].detach().clone()
    activation = "silu" if silu_activation else None

    out = causal_conv1d_update(x,
                               conv_state,
                               weight,
                               bias,
                               activation=activation,
                               conv_state_indices=row_to_line,
                               pad_slot_id=PAD_SLOT_ID)
    # Only the first batch_size rows are real sequences.
    out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
                                       conv_state_ref,
                                       weight,
                                       bias,
                                       activation=activation)

    assert torch.equal(conv_state[live_lines, :], conv_state_ref)
    assert torch.equal(conv_state[untouched_mask],
                       conv_state_before[untouched_mask])
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
262
263
264
265
266
267


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize('seqlen', [8, 30, 249, 2049, 4096])
@pytest.mark.parametrize('dim', [64, 4096])
@pytest.mark.parametrize('with_padding', [True, False])
@pytest.mark.parametrize('batch', [4, 10])
def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width,
                              has_bias, silu_activation, itype):
    """Check causal_conv1d_fn on a packed variable-length batch.

    ``seqlen`` tokens are split at random boundaries into non-empty
    sequences packed along the last axis and described by
    ``query_start_loc``. Each real sequence owns a random cache line in
    ``conv_states``. When ``with_padding`` is set, trailing entries of
    ``cache_indices`` are PAD_SLOT_ID and the reference skips them.
    Outputs and the touched cache lines are compared against
    ``causal_conv1d_ref`` run one sequence at a time.
    """
    device = "cuda"
    torch.cuda.empty_cache()
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    current_platform.seed_everything(0)

    batch_size = batch
    padding = 3 if with_padding else 0
    padded_batch_size = batch_size + padding
    nsplits = padded_batch_size - 1

    # Distinct, sorted end-of-sequence positions in [0, seqlen - 2], so every
    # sequence is non-empty.
    # NOTE(review): when seqlen - 1 < nsplits (e.g. seqlen=8, batch=10),
    # fewer than padded_batch_size sequences are produced. cache_indices is
    # then longer than the number of sequences and its PAD_SLOT_ID entries
    # are never reached, so the with_padding variant exercises no padding.
    eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
    seqlens = torch.diff(
        torch.cat([torch.tensor([-1]), eos_pos,
                   torch.tensor([seqlen - 1])])).tolist()
    assert sum(seqlens) == seqlen
    assert all(s > 0 for s in seqlens)

    total_entries = batch_size * 10
    cumsum = torch.cumsum(torch.tensor(seqlens), dim=0).to(torch.int32)
    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
                          dim=0)

    # Input is a strided slice of a wider buffer rather than a fresh tensor.
    x = rearrange(
        torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype),
        "b s d -> b d s")[:, 4096:4096 + dim, :]
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    activation = "silu" if silu_activation else None

    # Cache of shape (entries, dim, width - 1), contiguous along dim.
    final_states = torch.randn(total_entries,
                               width - 1,
                               dim,
                               device=x.device,
                               dtype=x.dtype).transpose(1, 2)
    final_states_ref = final_states.clone()
    has_initial_states = torch.randint(0,
                                       2, (cumsum.shape[0] - 1, ),
                                       dtype=torch.bool,
                                       device=x.device)
    state_indices = torch.randperm(total_entries,
                                   dtype=torch.int32,
                                   device=x.device)[:batch_size]
    pad_ids = torch.as_tensor([PAD_SLOT_ID] * padding,
                              dtype=torch.int32,
                              device=device)
    padded_state_indices = torch.concat([state_indices, pad_ids], dim=-1)

    out = causal_conv1d_fn(x.squeeze(0),
                           weight,
                           bias=bias,
                           conv_states=final_states,
                           query_start_loc=cumsum.to(device),
                           cache_indices=padded_state_indices,
                           has_initial_state=has_initial_states,
                           activation=activation,
                           pad_slot_id=PAD_SLOT_ID)

    # Reference: run each non-padded sequence on its own. The final state is
    # written back into final_states_ref through `cache_line`; this relies on
    # indexing with a 0-d integer tensor returning a view, not a copy.
    ref_outs = []
    for i, x_s in enumerate(torch.split(x_ref[0], seqlens, dim=-1)):
        state_idx = padded_state_indices[i]
        if state_idx == PAD_SLOT_ID:
            continue
        cache_line = final_states_ref[state_idx].unsqueeze(0)
        out_s, _ = causal_conv1d_ref(
            x_s.unsqueeze(0),
            weight_ref,
            bias_ref,
            activation=activation,
            return_final_states=True,
            final_states_out=cache_line,
            initial_states=cache_line if has_initial_states[i] else None)
        ref_outs.append(out_s)
    out_ref_tensor = torch.cat(ref_outs, dim=2)

    assert torch.allclose(final_states[state_indices],
                          final_states_ref[state_indices],
                          rtol=rtol,
                          atol=atol)
    # Padded sequences sit at the end of the packed batch, so the reference
    # covers a prefix of the kernel output.
    unpadded_out = out[:, :out_ref_tensor.shape[-1]]
    assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)