_aiter_ops.py 68.6 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable

import torch
7
from torch._ops import OpOverload
8
9
10

import vllm.envs as envs
from vllm.platforms import current_platform
11
from vllm.utils.import_utils import PlaceholderModule
12
from vllm.utils.torch_utils import direct_register_custom_op
13
14
15
16
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
    rocm_aiter_sparse_attn_indexer,
    rocm_aiter_sparse_attn_indexer_fake,
)
17

18
19
20
21
22
try:
    import pandas as pd
except ImportError:
    pd = PlaceholderModule("pandas")

23
24
25
26
# current_platform.fp8_dtype() is not memoized, and on ROCm it calls
# is_fp8_fnuz(), which is a host-side op. Resolve the FP8 dtype once at
# import time and reuse this constant throughout the module.
FP8_DTYPE = current_platform.fp8_dtype()
vllmellm's avatar
vllmellm committed
27

28
29
30
31
32
33
34
35
36
37
38
39
40

def is_aiter_found() -> bool:
    """Return True if the ``aiter`` package can be located (without importing it)."""
    import importlib.util

    aiter_spec = importlib.util.find_spec("aiter")
    return aiter_spec is not None


# `find_spec` cannot be traced by torch.compile. Availability checks may run
# inside compiled forward passes, so probe once at import time and expose the
# answer as a plain module-level constant.
IS_AITER_FOUND = is_aiter_found()

41

42
def is_aiter_found_and_supported() -> bool:
    """Report whether AITER *can* be used on this system.

    Returns True only when all of these hold:
    - the current platform is ROCm,
    - the ``aiter`` package was found at import time (``IS_AITER_FOUND``),
    - ``vllm.platforms.rocm.on_mi3xx()`` returns True for this device.

    Environment variables are deliberately NOT consulted here; that is the
    job of ``rocm_aiter_ops.is_enabled()``. The split of responsibilities:
    - this function: can aiter work here? (platform + library availability)
    - ``rocm_aiter_ops.is_enabled()``: should aiter be used by default?
      (adds the env-var check)
    - explicit backend selection: may request aiter regardless of env var.

    This lets explicit backend selection via attention_config work even when
    VLLM_ROCM_USE_AITER=0, while auto-discovery avoids unwanted JIT warnings.
    """
    if not current_platform.is_rocm() or not IS_AITER_FOUND:
        return False

    # Deferred import: only reached on ROCm with aiter present.
    from vllm.platforms.rocm import on_mi3xx

    return on_mi3xx()
63

64

65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@functools.cache
def _load_gemm_tuned_configs(
    q_dtype_w: torch.dtype, csv_path: str
) -> set[tuple[int, int, int]]:
    try:
        df = pd.read_csv(csv_path).drop_duplicates()
        df = df[df["q_dtype_w"] == str(q_dtype_w)]
        return set(zip(df["N"].astype(int), df["K"].astype(int), df["M"].astype(int)))
    except Exception:
        return set()


def _check_kernel_tuned(N: int, K: int, q_dtype_w: torch.dtype, csv_path: str) -> bool:
    """Return True if ``(N, K)`` was tuned for at least one probed M value.

    The probed M values are 1, 2, 4, every multiple of 8 up to 512, then
    1024, 1536 and the powers of two from 2**11 through 2**18.
    """
    tuned_shapes = _load_gemm_tuned_configs(q_dtype_w, csv_path)
    probe_ms = [
        1,
        2,
        4,
        *range(8, 513, 8),
        1024,
        1536,
        *(2**exp for exp in range(11, 19)),
    ]
    for m in probe_ms:
        if (N, K, m) in tuned_shapes:
            return True
    return False


88
89
def if_aiter_supported(func: Callable) -> Callable:
    """Decorator: call *func* only when ``is_aiter_found_and_supported()`` is True.

    When AITER is not usable on this system the wrapped function is skipped
    and the wrapper returns ``None``. Note that only platform/library support
    is checked here; the VLLM_ROCM_USE_AITER env var is not consulted.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_aiter_found_and_supported():
            return None
        return func(*args, **kwargs)

    return wrapper


def _rocm_aiter_fused_moe_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
    quant_method: int = 0,
    doweight_stage1: bool = False,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    num_local_tokens: torch.Tensor | None = None,
    output_dtype: torch.dtype | None = None,
    hidden_pad: int = 0,
    intermediate_pad: int = 0,
    bias1: torch.Tensor | None = None,
    bias2: torch.Tensor | None = None,
) -> torch.Tensor:
    """Eager body of the AITER fused-MoE op; forwards to ``aiter.fused_moe.fused_moe``.

    ``activation_method`` and ``quant_method`` are plain ints that are turned
    into aiter's ``ActivationType`` / ``QuantType`` enums before the call.
    ``output_dtype`` is forwarded as aiter's ``dtype`` keyword.
    """
    from aiter import ActivationType, QuantType
    from aiter.fused_moe import fused_moe

    leading_args = (
        hidden_states,
        w1,
        w2,
        topk_weight,
        topk_ids,
        expert_mask,
        ActivationType(activation_method),
        QuantType(quant_method),
        doweight_stage1,
        w1_scale,
        w2_scale,
        a1_scale,
        a2_scale,
    )
    return fused_moe(
        *leading_args,
        num_local_tokens=num_local_tokens,
        dtype=output_dtype,
        hidden_pad=hidden_pad,
        intermediate_pad=intermediate_pad,
        bias1=bias1,
        bias2=bias2,
    )


def _rocm_aiter_fused_moe_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
    quant_method: int = 0,
    doweight_stage1: bool = False,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
167
168
    num_local_tokens: torch.Tensor | None = None,
    output_dtype: torch.dtype | None = None,
169
170
171
172
    hidden_pad: int = 0,
    intermediate_pad: int = 0,
    bias1: torch.Tensor | None = None,
    bias2: torch.Tensor | None = None,
173
) -> torch.Tensor:
174
175
    if output_dtype is not None:
        return torch.empty_like(hidden_states, dtype=output_dtype)
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
    return torch.empty_like(hidden_states)


def _rocm_aiter_asm_moe_tkw1_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: torch.Tensor | None = None,
    fc2_scale: torch.Tensor | None = None,
    fc1_smooth_scale: torch.Tensor | None = None,
    fc2_smooth_scale: torch.Tensor | None = None,
    a16: bool = False,
    per_tensor_quant_scale: torch.Tensor | None = None,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
) -> torch.Tensor:
    """Eager body of the AITER ASM ``asm_moe_tkw1`` MoE op.

    ``activation_method`` is converted to aiter's ``ActivationType`` enum; all
    optional tensors are forwarded by keyword.
    """
    from aiter import ActivationType
    from aiter.fused_moe_bf16_asm import asm_moe_tkw1

    act = ActivationType(activation_method)
    optional_kwargs = dict(
        fc1_scale=fc1_scale,
        fc2_scale=fc2_scale,
        fc1_smooth_scale=fc1_smooth_scale,
        fc2_smooth_scale=fc2_smooth_scale,
        a16=a16,
        per_tensor_quant_scale=per_tensor_quant_scale,
        expert_mask=expert_mask,
        activation=act,
    )
    return asm_moe_tkw1(
        hidden_states, w1, w2, topk_weights, topk_ids, **optional_kwargs
    )


def _rocm_aiter_asm_moe_tkw1_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: torch.Tensor | None = None,
    fc2_scale: torch.Tensor | None = None,
    fc1_smooth_scale: torch.Tensor | None = None,
    fc2_smooth_scale: torch.Tensor | None = None,
    a16: bool = False,
    per_tensor_quant_scale: torch.Tensor | None = None,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)


def _rocm_aiter_topk_softmax_impl(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> None:
    """Eager body of the top-k softmax op; delegates to ``aiter.topk_softmax``.

    Returns nothing. The output buffers are passed to aiter, which presumably
    fills them in place (confirm against aiter).
    """
    from aiter import topk_softmax as aiter_topk_softmax

    aiter_topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )


def _rocm_aiter_topk_softmax_fake(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> None:
    pass


258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
def _rocm_aiter_topk_sigmoid_impl(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    """Eager body of the top-k sigmoid op; delegates to ``aiter.topk_sigmoid``."""
    from aiter import topk_sigmoid as aiter_topk_sigmoid

    aiter_topk_sigmoid(
        topk_weights,
        topk_indices,
        gating_output,
    )


def _rocm_aiter_topk_sigmoid_fake(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    pass


276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
def _rocm_aiter_biased_grouped_topk_impl(
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    routed_scaling_factor: float = 1.0,  # mul to topk_weights
) -> None:
    """Eager body of biased grouped top-k; forwards everything to aiter positionally.

    ``routed_scaling_factor`` is a multiplier applied to ``topk_weights``.
    """
    from aiter import biased_grouped_topk

    call_args = (
        gating_output,
        correction_bias,
        topk_weights,
        topk_ids,
        num_expert_group,
        topk_group,
        need_renorm,
        routed_scaling_factor,
    )
    biased_grouped_topk(*call_args)


def _rocm_aiter_biased_grouped_topk_fake(
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    routed_scaling_factor: float = 1.0,  # mul to topk_weights
) -> None:
    pass


def _rocm_aiter_grouped_topk_impl(
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,  # mul to topk_weights
) -> None:
    """Eager body of grouped top-k; forwards to ``aiter.grouped_topk``.

    aiter takes a boolean ``is_softmax`` flag: it is True exactly when
    ``scoring_func == "softmax"``, and any other string maps to False.
    """
    from aiter import grouped_topk

    use_softmax = scoring_func == "softmax"
    grouped_topk(
        gating_output,
        topk_weights,
        topk_ids,
        num_expert_group,
        topk_group,
        need_renorm,
        use_softmax,
        routed_scaling_factor,
    )


def _rocm_aiter_grouped_topk_fake(
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,  # mul to topk_weights
) -> None:
    pass


351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
def _rocm_aiter_fused_topk_impl(
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    gate_up: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Eager body of fused top-k; returns aiter's ``(topk_weights, topk_indices)``."""
    from aiter.fused_moe import fused_topk

    weights_and_indices = fused_topk(x, router_logits, top_k, gate_up)
    return weights_and_indices


def _rocm_aiter_fused_topk_fake(
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    gate_up: bool,
368
369
370
371
372
373
374
) -> tuple[torch.Tensor, torch.Tensor]:
    num_tokens = x.shape[0]
    topk_weights = torch.empty(
        (num_tokens, top_k), dtype=torch.float32, device=x.device
    )
    topk_indices = torch.empty((num_tokens, top_k), dtype=torch.int32, device=x.device)
    return topk_weights, topk_indices
375
376


377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
# Cache whether aiter supports FP8 MLA parameters
_AITER_MLA_SUPPORTS_FP8: bool | None = None


def _check_aiter_mla_fp8_support() -> bool:
    """Check if aiter.mla.mla_decode_fwd supports q_scale and kv_scale parameters."""
    global _AITER_MLA_SUPPORTS_FP8
    if _AITER_MLA_SUPPORTS_FP8 is None:
        try:
            import inspect

            from aiter.mla import mla_decode_fwd

            sig = inspect.signature(mla_decode_fwd)
            _AITER_MLA_SUPPORTS_FP8 = (
                "q_scale" in sig.parameters and "kv_scale" in sig.parameters
            )
394
395
396
397
398
399
400
401
402
403
404
        except (
            ImportError,
            ModuleNotFoundError,
            AttributeError,
            ValueError,
            TypeError,
        ):
            # ImportError/ModuleNotFoundError: aiter.mla module not available
            # AttributeError: mla_decode_fwd doesn't exist
            # ValueError: mla_decode_fwd has no signature (e.g., built-in)
            # TypeError: mla_decode_fwd is not a callable
405
406
407
408
            _AITER_MLA_SUPPORTS_FP8 = False
    return _AITER_MLA_SUPPORTS_FP8


409
410
411
412
413
414
415
416
417
418
419
def _rocm_aiter_mla_decode_fwd_impl(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
    q_scale: torch.Tensor | None = None,
    kv_scale: torch.Tensor | None = None,
    work_meta_data: torch.Tensor | None = None,
    work_indptr: torch.Tensor | None = None,
    work_info_set: torch.Tensor | None = None,
    reduce_indptr: torch.Tensor | None = None,
    reduce_final_map: torch.Tensor | None = None,
    reduce_partial_map: torch.Tensor | None = None,
) -> None:
    """Eager body of the AITER MLA decode op; writes its result via ``o``.

    - ``kv_buffer`` is viewed as ``(-1, 1, 1, q.shape[-1])`` before the call.
    - ``q_scale``/``kv_scale`` are forwarded only if the installed aiter's
      ``mla_decode_fwd`` declares them (see ``_check_aiter_mla_fp8_support``).
    - The ``work_*``/``reduce_*`` buffers are all-or-nothing. When
      ``work_meta_data`` is given, every other one must be non-None
      (asserted) and all six are forwarded. Otherwise none are passed.
    """
    from aiter.mla import mla_decode_fwd

    forwarded: dict[str, float | torch.Tensor | None] = dict(
        sm_scale=sm_scale,
        logit_cap=logit_cap,
    )

    if _check_aiter_mla_fp8_support():
        forwarded.update(q_scale=q_scale, kv_scale=kv_scale)

    if work_meta_data is not None:
        companions = {
            "work_indptr": work_indptr,
            "work_info_set": work_info_set,
            "reduce_indptr": reduce_indptr,
            "reduce_final_map": reduce_final_map,
            "reduce_partial_map": reduce_partial_map,
        }
        for buf_name, buf in companions.items():
            assert buf is not None, f"{buf_name} must be provided with work_meta_data"
        forwarded["work_meta_data"] = work_meta_data
        forwarded.update(companions)

    flat_kv = kv_buffer.view(-1, 1, 1, q.shape[-1])
    mla_decode_fwd(
        q,
        flat_kv,
        o,
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        max_seqlen_qo,
        **forwarded,
    )


def _rocm_aiter_mla_decode_fwd_fake(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
488
489
    q_scale: torch.Tensor | None = None,
    kv_scale: torch.Tensor | None = None,
490
491
492
493
494
495
    work_meta_data: torch.Tensor | None = None,
    work_indptr: torch.Tensor | None = None,
    work_info_set: torch.Tensor | None = None,
    reduce_indptr: torch.Tensor | None = None,
    reduce_final_map: torch.Tensor | None = None,
    reduce_partial_map: torch.Tensor | None = None,
496
497
498
499
) -> None:
    pass


500
def _rocm_aiter_w8a8_gemm_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Eager body of the CK W8A8 GEMM op via ``aiter.gemm_a8w8_CK``.

    ``gemm_a8w8_CK(a, b, scale_a, scale_b, bias, dtype)`` expects ``A`` as
    ``[M, K]`` and ``B`` as ``[N, K]``. NOTE(review):
    CutlassInt8ScaledMMLinearKernel prepares its weight ``w_q`` in ``[K, N]``
    format, so callers presumably hand this op an ``[N, K]`` tensor/view.
    Confirm at the call sites.
    """
    from aiter import gemm_a8w8_CK

    result = gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
    return result


517
def _rocm_aiter_w8a8_gemm_fake(
518
519
520
521
522
523
524
525
526
527
528
529
530
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    m = A.shape[0]
    n = B.shape[0]
    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
    return Y


531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
def _rocm_aiter_preshuffled_per_token_w8a8_gemm_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Eager body of the pre-shuffled-B W8A8 GEMM via ``aiter.gemm_a8w8_bpreshuffle``.

    The kernel is called with ``bias=None``. Any bias is instead added
    in place to the returned tensor afterwards.
    """
    from aiter import gemm_a8w8_bpreshuffle

    product = gemm_a8w8_bpreshuffle(A, B, As, Bs, None, output_dtype)
    if bias is None:
        return product
    product.add_(bias)
    return product


def _rocm_aiter_preshuffled_per_token_w8a8_gemm_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    m = A.shape[0]
    n = B.shape[0]
    return torch.empty(m, n, dtype=output_dtype, device=A.device)


560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
def _rocm_aiter_triton_gemm_a8w8_blockscale_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Eager body of the Triton block-scaled W8A8 GEMM from aiter."""
    from aiter.ops.triton.gemm_a8w8_blockscale import (
        gemm_a8w8_blockscale as triton_blockscale_gemm,
    )

    result = triton_blockscale_gemm(A, B, As, Bs, dtype=output_dtype)
    return result


def _rocm_aiter_triton_gemm_a8w8_blockscale_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    m = A.shape[0]
    n = B.shape[0]
    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
    return Y


585
def _rocm_aiter_gemm_a8w8_blockscale_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Eager body of the block-scaled W8A8 GEMM exported at ``aiter`` top level."""
    from aiter import gemm_a8w8_blockscale as blockscale_gemm

    result = blockscale_gemm(A, B, As, Bs, dtype=output_dtype)
    return result


597
def _rocm_aiter_gemm_a8w8_blockscale_fake(
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    m = A.shape[0]
    n = B.shape[0]
    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
    return Y


def _rocm_aiter_rms_norm_impl(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """RMSNorm via ``aiter.rms_norm``.

    Inputs with more than two dims are flattened to ``(-1, x.shape[-1])`` for
    the kernel, and the result is reshaped back to the original shape.
    """
    from aiter import rms_norm

    if x.dim() <= 2:
        return rms_norm(x, weight, variance_epsilon)

    original_shape = x.shape
    flattened = x.reshape(-1, original_shape[-1])
    normed = rms_norm(flattened, weight, variance_epsilon)
    return normed.reshape(original_shape)


def _rocm_aiter_rms_norm_fake(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    return torch.empty_like(x)


def _rocm_aiter_rmsnorm2d_fwd_with_add_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + RMSNorm via ``aiter.rmsnorm2d_fwd_with_add``.

    Fresh output buffers (shaped like ``x`` and ``residual``) are allocated
    and handed to the kernel. Returns ``(normalized, residual_out)``.
    """
    from aiter import rmsnorm2d_fwd_with_add

    normalized = torch.empty_like(x)
    new_residual = torch.empty_like(residual)
    rmsnorm2d_fwd_with_add(
        normalized,  # output
        x,  # input
        residual,  # residual input
        new_residual,  # residual output
        weight,
        variance_epsilon,
    )
    return normalized, new_residual
649
650
651
652
653
654
655
656


def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
657
658
659
660
661
662
663
664
665
666
667
668
669
670
    residual_out = torch.empty_like(residual)
    out = torch.empty_like(x)
    return out, residual_out


def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Residual-add + RMSNorm + dynamic quantization in one aiter kernel.

    ``quant_dtype`` must be ``torch.int8`` or ``FP8_DTYPE``. Returns
    ``(quantized, residual_out, scale)``, where ``scale`` is float32 of shape
    ``(x.shape[0], 1)`` and ``residual_out`` is shaped/typed like ``x``.
    """
    import aiter as rocm_aiter

    assert quant_dtype in [torch.int8, FP8_DTYPE]

    quantized = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
    residual_out = torch.empty_like(x)
    row_scales = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)

    rocm_aiter.rmsnorm2d_fwd_with_add_dynamicquant(
        quantized,
        x,
        residual,
        residual_out,
        row_scales,
        weight,
        epsilon,
        use_model_sensitive_rmsnorm=0,
    )
    return quantized, residual_out, row_scales


def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_fake(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
    out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
    residual_out = torch.empty_like(x)

    return out, residual_out, y_scale


def _rocm_aiter_rmsnorm_fused_dynamic_quant_impl(
    x: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """RMSNorm + dynamic quantization in one aiter kernel.

    ``quant_dtype`` must be ``torch.int8`` or ``FP8_DTYPE``. Returns
    ``(quantized, scale)``, where ``scale`` is float32 of shape ``(x.shape[0], 1)``.
    """
    import aiter as rocm_aiter

    assert quant_dtype in [torch.int8, FP8_DTYPE]

    quantized = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
    row_scales = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)

    rocm_aiter.rmsnorm2d_fwd_with_dynamicquant(
        quantized,
        x,
        row_scales,
        weight,
        epsilon,
        use_model_sensitive_rmsnorm=0,
    )
    return quantized, row_scales


def _rocm_aiter_rmsnorm_fused_dynamic_quant_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
    out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)

    return out, y_scale
735
736


vllmellm's avatar
vllmellm committed
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
def _rocm_aiter_per_tensor_quant_impl(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-tensor quantization via aiter's ``per_tensor_quant_hip``.

    Note the argument order aiter expects: ``(x, scale, quant_dtype)``.
    """
    from aiter.ops.quant import per_tensor_quant_hip

    quantized_and_scale = per_tensor_quant_hip(x, scale, quant_dtype)
    return quantized_and_scale


def _rocm_aiter_per_tensor_quant_fake(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(x, dtype=quant_dtype), torch.empty(
        1, dtype=torch.float32, device=x.device
    )


def _rocm_aiter_per_token_quant_impl(
    x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dynamic per-token (row-wise) quantization via aiter.

    Args:
        x: Input tensor; one scale is produced per slice along the last dim.
        quant_dtype: Output dtype; must be ``torch.int8`` or ``FP8_DTYPE``.
        scale: Optional float32 buffer of shape ``(*x.shape[:-1], 1)`` handed
            to the kernel. One is allocated when omitted. The kernel
            presumably writes the computed scales into it (confirm against
            aiter).

    Returns:
        ``(quantized, scale)`` where ``quantized`` has ``x``'s shape and
        ``quant_dtype``.
    """
    from aiter.ops.quant import dynamic_per_token_scaled_quant

    assert quant_dtype in [torch.int8, FP8_DTYPE]

    out_shape = x.shape
    # Allocate in the *requested* dtype: the output buffer's dtype is what the
    # kernel writes, so hard-coding FP8_DTYPE here would silently return FP8
    # data for int8 requests.
    out = torch.empty(out_shape, dtype=quant_dtype, device=x.device)
    if scale is None:
        scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device)
    dynamic_per_token_scaled_quant(
        out,
        x,
        scale,
        scale_ub=None,
        shuffle_scale=False,
        num_rows=None,
        num_rows_factor=1,
    )
    return out, scale


def _rocm_aiter_per_token_quant_fake(
    x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    out_shape = x.shape
    return (
785
        torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device),
vllmellm's avatar
vllmellm committed
786
787
788
789
        torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device),
    )


790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Residual-add + RMSNorm + FP8 group quantization via aiter's Triton kernel.

    Returns ``(x_quant, residual_out, x_quant_scales)``. The kernel's second
    and third outputs are not used here.
    """
    from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant

    quant_pair, _, _, residual_out = fused_rms_fp8_group_quant(
        x,
        weight,
        variance_epsilon,
        None,
        None,
        None,
        group_size=group_size,
        dtype_quant=FP8_DTYPE,
        res1=residual,
    )
    x_quant, x_quant_scales = quant_pair
    return x_quant, residual_out, x_quant_scales
815
816
817
818
819
820
821
822
823
824
825
826


def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake kernel: ``(fp8 like x, like residual, float32 (M, ceil(N / group_size)))``.

    ``x`` must be 2-D ``(M, N)``.
    """
    num_rows, num_cols = x.shape
    num_groups = (num_cols + group_size - 1) // group_size
    quantized = torch.empty_like(x, dtype=FP8_DTYPE, device=x.device)
    residual_out = torch.empty_like(residual, device=residual.device)
    group_scales = torch.empty(
        (num_rows, num_groups), dtype=torch.float32, device=x.device
    )
    return quantized, residual_out, group_scales


def _rocm_aiter_rmsnorm_fp8_group_quant_impl(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """RMSNorm + FP8 group quantization (no residual) via aiter's Triton kernel.

    Returns ``(x_quant, x_quant_scales)``. The kernel's remaining outputs are
    discarded.
    """
    from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant

    quant_pair, _, _, _ = fused_rms_fp8_group_quant(
        x,
        weight,
        variance_epsilon,
        None,
        None,
        None,
        group_size=group_size,
        dtype_quant=FP8_DTYPE,
        res1=None,
    )
    x_quant, x_quant_scales = quant_pair
    return (x_quant, x_quant_scales)


def _rocm_aiter_rmsnorm_fp8_group_quant_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake kernel: FP8 tensor like ``x`` plus float32 ``(M, ceil(N / group_size))`` scales."""
    num_rows, num_cols = x.shape
    num_groups = (num_cols + group_size - 1) // group_size
    quantized = torch.empty_like(x, dtype=FP8_DTYPE, device=x.device)
    group_scales = torch.empty(
        (num_rows, num_groups), dtype=torch.float32, device=x.device
    )
    return (quantized, group_scales)


def _rocm_aiter_group_fp8_quant_impl(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """FP8 group quantization using aiter's HIP ``QuantType.per_1x128`` kernel.

    NOTE(review): ``group_size`` is only used to validate the input width. The
    kernel selected is always ``per_1x128``, so group sizes other than 128
    presumably disagree with the fake's scale shape. Confirm callers only pass 128.
    """
    assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
    from aiter import QuantType, get_hip_quant

    quantize_1x128 = get_hip_quant(QuantType.per_1x128)
    dense_x = x.contiguous()
    return quantize_1x128(dense_x, quant_dtype=FP8_DTYPE)
878
879
880
881
882
883
884


def _rocm_aiter_group_fp8_quant_fake(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake kernel: FP8 ``(M, N)`` tensor plus float32 ``(M, ceil(N / group_size))`` scales."""
    num_rows, num_cols = x.shape
    num_groups = (num_cols + group_size - 1) // group_size
    quantized = torch.empty((num_rows, num_cols), dtype=FP8_DTYPE, device=x.device)
    group_scales = torch.empty(
        (num_rows, num_groups), dtype=torch.float32, device=x.device
    )
    return quantized, group_scales


def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """SiLU-and-mul followed by FP8 group quantization via aiter's Triton kernel."""
    from aiter.ops.triton.activation import act_mul_and_fp8_group_quant

    kernel_options = dict(
        activation="silu",
        group_size=group_size,
        dtype_quant=FP8_DTYPE,
    )
    return act_mul_and_fp8_group_quant(x, **kernel_options)


def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake kernel: the last dim is halved (gate * up), then group-quantized.

    ``x`` must be 2-D ``(M, N)`` with even ``N``. Returns an FP8 ``(M, N // 2)``
    tensor and float32 ``(M, ceil((N // 2) / group_size))`` scales.
    """
    num_rows, num_cols = x.shape
    assert num_cols % 2 == 0
    half_cols = num_cols // 2
    num_groups = (half_cols + group_size - 1) // group_size
    quantized = torch.empty((num_rows, half_cols), dtype=FP8_DTYPE, device=x.device)
    group_scales = torch.empty(
        (num_rows, num_groups), dtype=torch.float32, device=x.device
    )
    return quantized, group_scales


930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
def _rocm_aiter_triton_add_rmsnorm_pad_impl(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    residual: torch.Tensor,
    x_pad_to_multiple: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Residual-add + RMSNorm with optional last-dim padding (aiter Triton kernel)."""
    from aiter.ops.triton.fused_add_rmsnorm_pad import (
        fused_add_rmsnorm_pad as add_rmsnorm_pad_kernel,
    )

    outputs = add_rmsnorm_pad_kernel(
        x,
        weight,
        variance_epsilon,
        residual,
        x_pad_to_multiple=x_pad_to_multiple,
    )
    return outputs


def _rocm_aiter_triton_add_rmsnorm_pad_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    residual: torch.Tensor,
    x_pad_to_multiple: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    M, N = x.shape
    if x_pad_to_multiple > 0:
        N_out = (N + x_pad_to_multiple - 1) // x_pad_to_multiple * x_pad_to_multiple
    else:
        N_out = N
    out = torch.empty((M, N_out), dtype=x.dtype, device=x.device)
    residual_out = torch.empty_like(residual)
    return out, residual_out


965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
def _rocm_aiter_gemm_a8wfp4_impl(
    x: torch.Tensor,
    w: torch.Tensor,
    x_scales: torch.Tensor,
    w_scales: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """A8/W-FP4 GEMM via aiter's Triton ``gemm_a8wfp4``.

    Allocates a ``(x.shape[0], w.shape[0])`` output of ``out_dtype``, passes
    it to the kernel as ``y`` and returns it.
    """
    from aiter.ops.triton.gemm_a8wfp4 import gemm_a8wfp4

    result = torch.empty(x.shape[0], w.shape[0], dtype=out_dtype, device=x.device)
    gemm_a8wfp4(
        x=x,
        w=w,
        y=result,
        x_scales=x_scales,
        w_scales=w_scales,
        dtype=out_dtype,
        config=None,
    )
    return result


def _rocm_aiter_gemm_a8wfp4_fake(
    x: torch.Tensor,
    w: torch.Tensor,
    x_scales: torch.Tensor,
    w_scales: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Shape-only stand-in registered as the fake for ``rocm_aiter_gemm_a8wfp4``.

    Allocates an uninitialized ``(x.shape[0], w.shape[0])`` tensor of
    ``out_dtype`` on ``x``'s device, matching the real op's output buffer.
    """
    rows, cols = x.shape[0], w.shape[0]
    return torch.empty((rows, cols), dtype=out_dtype, device=x.device)


998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
def _triton_rotary_embedding_impl(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    offsets: torch.Tensor | None = None,
) -> None:
    # Modifies query and key in-place
    from aiter.ops.triton.rope.rope import (
        rope_cached_thd_positions_offsets_2c_fwd_inplace,
    )

    num_tokens = positions.numel()
    cos, sin = cos_sin_cache.chunk(2, dim=-1)
    query_shape = query.shape
    key_shape = key.shape
    rotate_style = 0 if is_neox else 1
    rotary_dim = head_size

    query = query.view(num_tokens, -1, head_size)
    key = key.view(num_tokens, -1, head_size)
    query_ = query[..., :rotary_dim]
    key_ = key[..., :rotary_dim]
    positions = positions.view(*query.shape[:1])
    rope_cached_thd_positions_offsets_2c_fwd_inplace(
        query_,
        key_,
        cos,
        sin,
        positions,
        offsets,
        rotate_style,
        reuse_freqs_front_part=True,
        nope_first=False,
    )
    query = query.view(query_shape)
    key = key.view(key_shape)


def _triton_rotary_embedding_fake(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    offsets: torch.Tensor | None = None,
) -> None:
    return


1051
1052
1053
1054
1055
# Module-level guard: rocm_aiter_ops.register_ops_once() sets this to True after
# its first pass so the AITER custom ops are never registered twice.
_OPS_REGISTERED: bool = False


class rocm_aiter_ops:
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
    """ROCm AITER operations wrapper for AMD GPU acceleration in vLLM.

    This class centralizes the import and registration of AITER ops,
    and provides a unified interface for checking if AITER is enabled.
    Operations are only available on supported gfx9
    architectures when aiter is installed.

    The class uses environment variables to control which features are enabled,
    allowing fine-grained control over which AITER optimizations are used.

    Environment Variables:
        VLLM_ROCM_USE_AITER: Main toggle for all AITER operations.
        VLLM_ROCM_USE_AITER_LINEAR: Controls GEMM and quantization ops.
        VLLM_ROCM_USE_AITER_RMSNORM: Controls RMSNorm operations.
        VLLM_ROCM_USE_AITER_MOE: Controls MoE (Mixture of Experts) ops.
        VLLM_ROCM_USE_AITER_MLA: Controls MLA (Multi-head Latent Attention) ops.
        VLLM_ROCM_USE_AITER_MHA: Controls MHA ops including flash_attn_varlen.
        VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: Controls Triton unified attention.
        VLLM_ROCM_USE_AITER_FP8BMM: Controls FP8 batched matrix multiply.
        VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: Controls FP4 assembly GEMM.
        VLLM_ROCM_USE_AITER_TRITON_ROPE: Controls Triton rotary embeddings.
        VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: Controls shared expert fusion.
        VLLM_ROCM_USE_AITER_TRITON_GEMM: Controls Triton unquantized GEMM.

    Note:
        The environment variables are assigned when the module is imported,
        so you can't change the environment variables after the module is imported.
        This is done out of performance consideration. Accessing environment variables
        is expensive as described in issue https://github.com/vllm-project/vllm/issues/17067
        so we don't want to do it repeatedly, especially in the hot path (the forward pass).
        You can call the refresh_env_variables() function to reload the env variables
        after monkey patching the env variables in the unit test.

    Check Functions:
        All check functions (is_*_enabled) are decorated with @if_aiter_supported,
        which verifies: (1) platform is ROCm, (2) device arch is gfx9, and
        (3) aiter library is installed. The check function then also verifies
        the corresponding environment variable is enabled.
        i.e.                                             ___
        is_enabled() == current_platform.is_rocm() and      |     checked by
                        current_platform.is_on_gfx9() and   | @if_aiter_supported
                        IS_AITER_FOUND and   _______________|
                        cls._AITER_ENABLED   -----> Check by the logic in `is_enabled()`

    Example:
        from vllm._aiter_ops import rocm_aiter_ops

        # Check if aiter is enabled before using operations
        if rocm_aiter_ops.is_enabled():
            result = rocm_aiter_ops.rms_norm(x, weight, epsilon)

    Operations:
        - RMS normalization: rms_norm, rms_norm2d_with_add
        - GEMM operations: gemm_a8w8, gemm_a8w8_blockscale
        - Fused MoE: fused_moe, asm_moe_tkw1
        - Routing: topk_softmax, biased_grouped_topk, grouped_topk
        - MLA decode: mla_decode_fwd
        - Quantization: per_tensor_quant, per_token_quant, group_fp8_quant
        - Triton ops: triton_rotary_embed, triton_fp8_bmm, triton_gemm_a8w8_blockscale
    """

    # Check if the env variable is set
1118
1119
1120
1121
1122
1123
    _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
    _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
    _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
    _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
    _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
    _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
1124
    _SHUFFLE_KV_CACHE_ENABLED = envs.VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT
1125
    _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
1126
    # TODO: Consolidate under _LINEAR_ENABLED
1127
    _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
1128
    _FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM
1129
    # TODO: Consolidate under _LINEAR_ENABLED
1130
    _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
1131
    # TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE
1132
1133
    _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
    _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
1134
    # TODO: Consolidate under _LINEAR_ENABLED
1135
    _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
1136

1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
    @classmethod
    def refresh_env_variables(cls):
        """
        Re-read every cached AITER feature flag from ``vllm.envs``.

        The flags are stored as class attributes when the module is imported,
        which keeps environment lookups off the hot path. Call this after
        monkey patching environment variables (e.g. in unit tests) so the
        ``is_*_enabled`` checks observe the new values.
        """
        cls._AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
        cls._LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
        cls._RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
        cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
        cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
        cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
        cls._SHUFFLE_KV_CACHE_ENABLED = envs.VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT
        cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
        cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
        cls._FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM
        cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
        cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
        cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
        cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM

1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
    @staticmethod
    def get_aiter_activation_type(activation_str: str):
        """
        Translate an activation name into aiter's ``ActivationType`` enum.

        Matching ignores surrounding whitespace and letter case.
        Accepted names: "no", "none", "silu", "gelu", "swiglu".

        Args:
            activation_str (str): Activation name.

        Returns:
            The matching ``ActivationType`` member, or None when aiter cannot
            be imported, the argument is not a string, or the name is unknown.
        """
        # aiter is optional, so it is resolved lazily at call time.
        try:
            from aiter import ActivationType
        except ImportError:
            return None

        if not isinstance(activation_str, str):
            return None

        aliases_by_member = (
            (("none", "no"), ActivationType.No),
            (("silu",), ActivationType.Silu),
            (("gelu",), ActivationType.Gelu),
            (("swiglu",), ActivationType.Swiglu),
        )
        lookup = {
            alias: member
            for aliases, member in aliases_by_member
            for alias in aliases
        }
        return lookup.get(activation_str.strip().lower())

    @staticmethod
    def get_aiter_quant_type(quant_type_str: str):
        """
        Translate a quantization-scheme name into aiter's ``QuantType`` enum.

        Matching ignores surrounding whitespace and letter case. Accepted
        names: "no", "per_tensor", "per_token", "per_1x32", "per_1x128",
        "per_128x128".

        Args:
            quant_type_str (str): Quantization scheme name.

        Returns:
            The matching ``QuantType`` member, or None when aiter cannot be
            imported, the argument is not a string, or the name is unknown.
        """
        # aiter is an optional dependency; resolve it lazily.
        try:
            from aiter import QuantType
        except ImportError:
            return None

        if not isinstance(quant_type_str, str):
            return None

        pairs = [
            ("no", QuantType.No),
            ("per_tensor", QuantType.per_Tensor),
            ("per_token", QuantType.per_Token),
            ("per_1x32", QuantType.per_1x32),
            ("per_1x128", QuantType.per_1x128),
            ("per_128x128", QuantType.per_128x128),
        ]
        wanted = quant_type_str.strip().lower()
        return dict(pairs).get(wanted)

1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
    @classmethod
    @if_aiter_supported
    def is_enabled(cls) -> bool:
        """Report whether AITER is switched on via ``VLLM_ROCM_USE_AITER``.

        Platform, architecture and library checks live in the
        ``if_aiter_supported`` decorator; this body only consults the flag.
        """
        return cls._AITER_ENABLED

    @classmethod
    @if_aiter_supported
    def is_linear_enabled(cls) -> bool:
        """Report whether AITER GEMM/quantization (linear) ops should be used.

        Requires both ``VLLM_ROCM_USE_AITER`` and ``VLLM_ROCM_USE_AITER_LINEAR``.
        """
        return cls._AITER_ENABLED and cls._LINEAR_ENABLED

    @classmethod
    @if_aiter_supported
    def is_linear_fp8_enabled(cls) -> bool:
        """Report whether AITER FP8 linear ops should be used.

        Currently delegates to ``is_linear_enabled()`` with no extra
        FP8-specific condition.
        """
        return cls.is_linear_enabled()
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264

    @classmethod
    @if_aiter_supported
    def is_rmsnorm_enabled(cls) -> bool:
        """Report whether AITER RMSNorm kernels should be used.

        Requires ``VLLM_ROCM_USE_AITER`` and ``VLLM_ROCM_USE_AITER_RMSNORM``.
        """
        return cls._AITER_ENABLED and cls._RMSNORM_ENABLED

    @classmethod
    @if_aiter_supported
    def is_fused_moe_enabled(cls) -> bool:
        """Report whether AITER fused-MoE kernels should be used.

        Requires ``VLLM_ROCM_USE_AITER`` and ``VLLM_ROCM_USE_AITER_MOE``.
        """
        return cls._AITER_ENABLED and cls._FMOE_ENABLED

    @classmethod
    @if_aiter_supported
    def is_fusion_moe_shared_experts_enabled(cls) -> bool:
        """Report whether shared experts should be fused into AITER MoE.

        Builds on ``is_fused_moe_enabled()`` and additionally requires
        ``VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS``.
        """
        return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED

    @classmethod
    @if_aiter_supported
    def is_mla_enabled(cls) -> bool:
        """Report whether AITER MLA (multi-head latent attention) is enabled.

        Requires ``VLLM_ROCM_USE_AITER`` and ``VLLM_ROCM_USE_AITER_MLA``.
        """
        return cls._AITER_ENABLED and cls._MLA_ENABLED

    @classmethod
    @if_aiter_supported
    def is_mha_enabled(cls) -> bool:
        """Report whether AITER MHA kernels are enabled.

        Requires ``VLLM_ROCM_USE_AITER`` and ``VLLM_ROCM_USE_AITER_MHA``.
        """
        return cls._AITER_ENABLED and cls._MHA_ENABLED

1265
1266
1267
    @classmethod
    @if_aiter_supported
    def is_shuffle_kv_cache_enabled(cls) -> bool:
        """Report whether the shuffled KV-cache layout is requested.

        Reads only ``VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT``. Unlike the other
        feature checks, it is not combined with ``VLLM_ROCM_USE_AITER``.
        NOTE(review): presumably intentional so an explicitly selected AITER
        backend can use it regardless of the main toggle — confirm.
        """
        return cls._SHUFFLE_KV_CACHE_ENABLED
1269

1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
    @classmethod
    @if_aiter_supported
    def is_triton_unified_attn_enabled(cls) -> bool:
        """Report whether AITER's Triton unified attention should be used.

        Requires ``VLLM_ROCM_USE_AITER`` and
        ``VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION``.
        """
        return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED

    @classmethod
    @if_aiter_supported
    def is_fp8bmm_enabled(cls) -> bool:
        """Report whether AITER FP8 batched matmul should be used.

        Requires ``VLLM_ROCM_USE_AITER`` and ``VLLM_ROCM_USE_AITER_FP8BMM``.
        """
        return cls._AITER_ENABLED and cls._FP8BMM_ENABLED

1280
1281
1282
    @classmethod
    @if_aiter_supported
    def is_fp4bmm_enabled(cls) -> bool:
        """Report whether AITER FP4 batched matmul should be used.

        Requires ``VLLM_ROCM_USE_AITER``, ``VLLM_ROCM_USE_AITER_FP4BMM`` and a
        gfx950 device (checked last, only when both flags are truthy).
        """
        from vllm.platforms.rocm import on_gfx950

        return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and on_gfx950()
1286

1287
1288
1289
    @classmethod
    @if_aiter_supported
    def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
        """Report whether the AITER assembly FP4 GEMM (dynamic quant) is enabled.

        Requires ``VLLM_ROCM_USE_AITER``, ``VLLM_ROCM_USE_AITER_FP4_ASM_GEMM``
        and a gfx950 device (checked last, only when both flags are truthy).
        """
        from vllm.platforms.rocm import on_gfx950

        return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM and on_gfx950()
1293
1294
1295
1296
1297
1298

    @classmethod
    @if_aiter_supported
    def is_triton_rotary_embed_enabled(cls) -> bool:
        """Report whether AITER's Triton rotary embedding should be used.

        Requires ``VLLM_ROCM_USE_AITER`` and ``VLLM_ROCM_USE_AITER_TRITON_ROPE``.
        """
        return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED

1299
1300
1301
1302
1303
    @classmethod
    @if_aiter_supported
    def is_triton_gemm_enabled(cls) -> bool:
        """Report whether AITER's Triton unquantized GEMM should be used.

        Requires ``VLLM_ROCM_USE_AITER`` and ``VLLM_ROCM_USE_AITER_TRITON_GEMM``.
        """
        return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM

1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
    @staticmethod
    @if_aiter_supported
    def register_ops_once() -> None:
        """Register every AITER-backed custom op under ``torch.ops.vllm``.

        The module-level ``_OPS_REGISTERED`` flag makes this idempotent. The
        first call performs every registration and then sets the flag; later
        calls return immediately.

        Each op pairs a real implementation with a fake implementation that
        only allocates outputs. ``mutates_args`` names the inputs an op writes
        in place. Some registrations pin ``dispatch_key`` to the current
        platform's key, and the rest omit it.
        """
        global _OPS_REGISTERED
        if _OPS_REGISTERED:
            return

        # MoE kernels and routing.
        direct_register_custom_op(
            op_name="rocm_aiter_asm_moe_tkw1",
            op_func=_rocm_aiter_asm_moe_tkw1_impl,
            mutates_args=[],
            fake_impl=_rocm_aiter_asm_moe_tkw1_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_fused_moe",
            op_func=_rocm_aiter_fused_moe_impl,
            mutates_args=[],
            fake_impl=_rocm_aiter_fused_moe_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_topk_softmax",
            op_func=_rocm_aiter_topk_softmax_impl,
            mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
            fake_impl=_rocm_aiter_topk_softmax_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_topk_sigmoid",
            op_func=_rocm_aiter_topk_sigmoid_impl,
            mutates_args=["topk_weights", "topk_indices"],
            fake_impl=_rocm_aiter_topk_sigmoid_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_biased_grouped_topk",
            op_func=_rocm_aiter_biased_grouped_topk_impl,
            mutates_args=["topk_weights", "topk_ids"],
            fake_impl=_rocm_aiter_biased_grouped_topk_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_grouped_topk",
            op_func=_rocm_aiter_grouped_topk_impl,
            mutates_args=["topk_weights", "topk_ids"],
            fake_impl=_rocm_aiter_grouped_topk_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_fused_topk",
            op_func=_rocm_aiter_fused_topk_impl,
            mutates_args=[],
            fake_impl=_rocm_aiter_fused_topk_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        # MLA decode (writes its result into ``o``).
        direct_register_custom_op(
            op_name="rocm_aiter_mla_decode_fwd",
            op_func=_rocm_aiter_mla_decode_fwd_impl,
            mutates_args=["o"],
            fake_impl=_rocm_aiter_mla_decode_fwd_fake,
        )

        # GEMMs.
        direct_register_custom_op(
            op_name="rocm_aiter_w8a8_gemm",
            op_func=_rocm_aiter_w8a8_gemm_impl,
            fake_impl=_rocm_aiter_w8a8_gemm_fake,
        )

        direct_register_custom_op(
            op_name="_rocm_aiter_preshuffled_per_token_w8a8_gemm",
            op_func=_rocm_aiter_preshuffled_per_token_w8a8_gemm_impl,
            fake_impl=_rocm_aiter_preshuffled_per_token_w8a8_gemm_fake,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_triton_gemm_a8w8_blockscale",
            op_func=_rocm_aiter_triton_gemm_a8w8_blockscale_impl,
            fake_impl=_rocm_aiter_triton_gemm_a8w8_blockscale_fake,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_gemm_a8w8_blockscale",
            op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
            fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
        )

        # RMSNorm family, including the fused-quantization variants.
        direct_register_custom_op(
            op_name="rocm_aiter_rms_norm",
            op_func=_rocm_aiter_rms_norm_impl,
            fake_impl=_rocm_aiter_rms_norm_fake,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
            op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
            fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_rmsnorm_fused_dynamic_quant",
            op_func=_rocm_aiter_rmsnorm_fused_dynamic_quant_impl,
            fake_impl=_rocm_aiter_rmsnorm_fused_dynamic_quant_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_rmsnorm_fused_add_dynamic_quant",
            op_func=_rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl,
            fake_impl=_rocm_aiter_rmsnorm_fused_add_dynamic_quant_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_rmsnorm_fp8_group_quant",
            op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl,
            fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant",
            op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl,
            fake_impl=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_act_mul_and_fp8_group_quant",
            op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
            fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_triton_add_rmsnorm_pad",
            op_func=_rocm_aiter_triton_add_rmsnorm_pad_impl,
            fake_impl=_rocm_aiter_triton_add_rmsnorm_pad_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        # Standalone quantization.
        direct_register_custom_op(
            op_name="rocm_aiter_group_fp8_quant",
            op_func=_rocm_aiter_group_fp8_quant_impl,
            fake_impl=_rocm_aiter_group_fp8_quant_fake,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_per_tensor_quant",
            op_func=_rocm_aiter_per_tensor_quant_impl,
            mutates_args=[],
            fake_impl=_rocm_aiter_per_tensor_quant_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_per_token_quant",
            op_func=_rocm_aiter_per_token_quant_impl,
            fake_impl=_rocm_aiter_per_token_quant_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        # Sparse attention indexer (fills ``topk_indices_buffer``).
        direct_register_custom_op(
            op_name="rocm_aiter_sparse_attn_indexer",
            op_func=rocm_aiter_sparse_attn_indexer,
            mutates_args=["topk_indices_buffer"],
            fake_impl=rocm_aiter_sparse_attn_indexer_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_gemm_a8wfp4",
            op_func=_rocm_aiter_gemm_a8wfp4_impl,
            mutates_args=[],
            fake_impl=_rocm_aiter_gemm_a8wfp4_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        # Rotary embedding: query and key are rotated in place.
        direct_register_custom_op(
            op_name="rocm_aiter_triton_rotary_embedding",
            op_func=_triton_rotary_embedding_impl,
            mutates_args=["query", "key"],
            fake_impl=_triton_rotary_embedding_fake,
        )

        _OPS_REGISTERED = True

1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
    @staticmethod
    def get_rmsnorm_fused_add_op() -> OpOverload:
        """Default overload of ``rocm_aiter_rmsnorm2d_fwd_with_add``."""
        return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default

    @staticmethod
    def get_rmsnorm_op() -> OpOverload:
        """Default overload of ``rocm_aiter_rms_norm``."""
        return torch.ops.vllm.rocm_aiter_rms_norm.default

    @staticmethod
    def get_rmsnorm_fused_add_dynamic_quant_op() -> OpOverload:
        """Default overload of ``rocm_aiter_rmsnorm_fused_add_dynamic_quant``."""
        return torch.ops.vllm.rocm_aiter_rmsnorm_fused_add_dynamic_quant.default

    @staticmethod
    def get_rmsnorm_fused_dynamic_quant_op() -> OpOverload:
        """Default overload of ``rocm_aiter_rmsnorm_fused_dynamic_quant``."""
        return torch.ops.vllm.rocm_aiter_rmsnorm_fused_dynamic_quant.default

    @staticmethod
    def get_rmsnorm_group_fused_quant_op() -> OpOverload:
        """Default overload of ``rocm_aiter_rmsnorm_fp8_group_quant``."""
        return torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default

    @staticmethod
    def get_rmsnorm_group_add_fused_quant_op() -> OpOverload:
        """Default overload of ``rocm_aiter_rmsnorm_with_add_fp8_group_quant``."""
        return torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default

    @staticmethod
    def get_per_token_quant_op() -> OpOverload:
        """Default overload of ``rocm_aiter_per_token_quant``."""
        return torch.ops.vllm.rocm_aiter_per_token_quant.default

    @staticmethod
    def get_group_quant_op() -> OpOverload:
        """Default overload of ``rocm_aiter_group_fp8_quant``."""
        return torch.ops.vllm.rocm_aiter_group_fp8_quant.default

    @staticmethod
    def get_act_mul_fused_fp8_group_quant_op() -> OpOverload:
        """Default overload of ``rocm_aiter_act_mul_and_fp8_group_quant``."""
        return torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default

1532
1533
1534
1535
    @staticmethod
    def get_triton_add_rmsnorm_pad_op() -> OpOverload:
        """Default overload of ``rocm_aiter_triton_add_rmsnorm_pad``."""
        return torch.ops.vllm.rocm_aiter_triton_add_rmsnorm_pad.default

1536
1537
1538
1539
    @staticmethod
    def get_triton_rotary_embedding_op() -> OpOverload:
        """Default overload of ``rocm_aiter_triton_rotary_embedding``."""
        return torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default

1540
1541
1542
1543
1544
1545
    @staticmethod
    def rms_norm(
        x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
    ) -> torch.Tensor:
        """RMS-normalize ``x`` with ``weight`` via the ``rocm_aiter_rms_norm`` op."""
        normalized = torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)
        return normalized

1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
    @staticmethod
    def rms_norm2d_with_add(
        x: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor,
        variance_epsilon: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the ``rocm_aiter_rmsnorm2d_fwd_with_add`` custom op.

        Returns the op's pair of tensors unchanged.
        """
        op = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
        return op(x, residual, weight, variance_epsilon)

    @staticmethod
    def w8a8_gemm(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None = None,
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        """Run the AITER W8A8 GEMM through the ``rocm_aiter_w8a8_gemm`` op.

        ``A``/``B`` are presumably the 8-bit operands and ``As``/``Bs`` their
        scales (TODO confirm against ``_rocm_aiter_w8a8_gemm_impl``). ``bias``
        and ``output_dtype`` are forwarded unchanged.
        """
        return torch.ops.vllm.rocm_aiter_w8a8_gemm(A, B, As, Bs, bias, output_dtype)

    @staticmethod
    def preshuffled_per_token_w8a8_gemm(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None = None,
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        """W8A8 GEMM variant backed by ``_rocm_aiter_preshuffled_per_token_w8a8_gemm``.

        Arguments are forwarded positionally and unchanged. Per its name, the
        op presumably expects a pre-shuffled ``B`` with per-token scales —
        confirm against the impl.
        """
        op = torch.ops.vllm._rocm_aiter_preshuffled_per_token_w8a8_gemm
        return op(A, B, As, Bs, bias, output_dtype)
1580

1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
    @staticmethod
    def triton_gemm_a8w8_blockscale(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        block_size: list[int],
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        """Block-scaled W8A8 GEMM via ``rocm_aiter_triton_gemm_a8w8_blockscale``.

        NOTE(review): ``block_size`` is accepted but not forwarded to the op —
        presumably the kernel derives blocking from the scale shapes; confirm.
        """
        op = torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale
        return op(A, B, As, Bs, output_dtype)

1594
    @staticmethod
    def gemm_a8w8_blockscale(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        block_size: list[int],
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        """Block-scaled W8A8 GEMM via the ``rocm_aiter_gemm_a8w8_blockscale`` op.

        NOTE(review): ``block_size`` is accepted for signature parity but not
        forwarded to the op; confirm the kernel infers blocking on its own.
        """
        return torch.ops.vllm.rocm_aiter_gemm_a8w8_blockscale(
            A, B, As, Bs, output_dtype
        )

    @staticmethod
    def fused_moe(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weight: torch.Tensor,
        topk_ids: torch.Tensor,
        expert_mask: torch.Tensor | None = None,
        activation_method: int = 0,
        quant_method: int = 0,
        doweight_stage1: bool = False,
        w1_scale: torch.Tensor | None = None,
        w2_scale: torch.Tensor | None = None,
        a1_scale: torch.Tensor | None = None,
        a2_scale: torch.Tensor | None = None,
        num_local_tokens: torch.Tensor | None = None,
        output_dtype: torch.dtype | None = None,
        hidden_pad: int = 0,
        intermediate_pad: int = 0,
        bias1: torch.Tensor | None = None,
        bias2: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run AITER fused MoE through the ``rocm_aiter_fused_moe`` custom op.

        All arguments are forwarded positionally, in signature order.
        ``activation_method`` and ``quant_method`` are integer codes;
        presumably they hold values of aiter's ``ActivationType``/``QuantType``
        enums (see ``get_aiter_activation_type``/``get_aiter_quant_type``) —
        confirm against ``_rocm_aiter_fused_moe_impl``.
        """
        return torch.ops.vllm.rocm_aiter_fused_moe(
            hidden_states,
            w1,
            w2,
            topk_weight,
            topk_ids,
            expert_mask,
            activation_method,
            quant_method,
            doweight_stage1,
            w1_scale,
            w2_scale,
            a1_scale,
            a2_scale,
            num_local_tokens,
            output_dtype,
            hidden_pad,
            intermediate_pad,
            bias1,
            bias2,
        )

    @staticmethod
    def asm_moe_tkw1(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        fc1_scale: torch.Tensor | None = None,
        fc2_scale: torch.Tensor | None = None,
        fc1_smooth_scale: torch.Tensor | None = None,
        fc2_smooth_scale: torch.Tensor | None = None,
        a16: bool = False,
        per_tensor_quant_scale: torch.Tensor | None = None,
        expert_mask: torch.Tensor | None = None,
        activation_method: int = 0,
    ) -> torch.Tensor:
        """Run the AITER assembly MoE kernel via ``rocm_aiter_asm_moe_tkw1``.

        Every argument is passed through positionally in declaration order.
        """
        op = torch.ops.vllm.rocm_aiter_asm_moe_tkw1
        forwarded = (
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            fc1_scale,
            fc2_scale,
            fc1_smooth_scale,
            fc2_smooth_scale,
            a16,
            per_tensor_quant_scale,
            expert_mask,
            activation_method,
        )
        return op(*forwarded)

    @staticmethod
    def topk_softmax(
        topk_weights: torch.Tensor,
        topk_indices: torch.Tensor,
        token_expert_indices: torch.Tensor,
        gating_output: torch.Tensor,
        renormalize: bool,
    ) -> tuple[torch.Tensor, ...]:
        """Softmax top-k routing via the ``rocm_aiter_topk_softmax`` op.

        The op is registered as mutating ``topk_weights``, ``topk_indices`` and
        ``token_expert_indices``. This returns the first two buffers after the
        call.
        """
        op = torch.ops.vllm.rocm_aiter_topk_softmax
        op(topk_weights, topk_indices, token_expert_indices, gating_output, renormalize)
        return topk_weights, topk_indices

1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
    @staticmethod
    def topk_sigmoid(
        topk_weights: torch.Tensor,
        topk_indices: torch.Tensor,
        token_expert_indices: torch.Tensor,
        gating_output: torch.Tensor,
        renormalize: bool,
    ) -> tuple[torch.Tensor, ...]:
        """Sigmoid top-k routing via the ``rocm_aiter_topk_sigmoid`` op.

        The op is registered as mutating ``topk_weights`` and ``topk_indices``.
        Both buffers are returned after the call.

        NOTE(review): ``token_expert_indices`` and ``renormalize`` are accepted
        (the signature mirrors ``topk_softmax``) but are NOT forwarded, so
        ``renormalize=True`` has no effect here — confirm this is intended or
        handled by the caller.
        """
        op = torch.ops.vllm.rocm_aiter_topk_sigmoid
        op(topk_weights, topk_indices, gating_output)
        return topk_weights, topk_indices

1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
    @staticmethod
    def biased_grouped_topk(
        gating_output: torch.Tensor,
        correction_bias: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_expert_group: int,
        topk_group: int,
        need_renorm: bool,
        routed_scaling_factor: float = 1.0,
    ) -> None:
        """Bias-corrected grouped top-k routing via ``rocm_aiter_biased_grouped_topk``.

        Results are written in place into ``topk_weights`` and ``topk_ids``
        (the op's registered ``mutates_args``). Nothing is returned.
        """
        op = torch.ops.vllm.rocm_aiter_biased_grouped_topk
        op(
            gating_output,
            correction_bias,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            need_renorm,
            routed_scaling_factor,
        )

    @staticmethod
    def grouped_topk(
        gating_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_expert_group: int,
        topk_group: int,
        need_renorm: bool,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
    ) -> None:
        """Grouped top-k routing via the ``rocm_aiter_grouped_topk`` op.

        Results are written in place into ``topk_weights`` and ``topk_ids``
        (the op's registered ``mutates_args``). ``scoring_func`` is forwarded
        unchanged.
        """
        op = torch.ops.vllm.rocm_aiter_grouped_topk
        op(
            gating_output,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            need_renorm,
            scoring_func,
            routed_scaling_factor,
        )

1753
1754
1755
1756
1757
1758
1759
1760
1761
    @staticmethod
    def fused_topk(
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        gate_up: bool,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fused top-k routing via the ``rocm_aiter_fused_topk`` op.

        Returns the op's two-tensor result unchanged.
        """
        op = torch.ops.vllm.rocm_aiter_fused_topk
        return op(x, router_logits, top_k, gate_up)

1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
    @staticmethod
    def mla_decode_fwd(
        q: torch.Tensor,
        kv_buffer: torch.Tensor,
        o: torch.Tensor,
        sm_scale: float,
        qo_indptr: torch.Tensor,
        max_seqlen_qo: int,
        kv_indptr: torch.Tensor | None = None,
        kv_indices: torch.Tensor | None = None,
        kv_last_page_lens: torch.Tensor | None = None,
        logit_cap: float = 0.0,
        q_scale: torch.Tensor | None = None,
        kv_scale: torch.Tensor | None = None,
        work_meta_data: torch.Tensor | None = None,
        work_indptr: torch.Tensor | None = None,
        work_info_set: torch.Tensor | None = None,
        reduce_indptr: torch.Tensor | None = None,
        reduce_final_map: torch.Tensor | None = None,
        reduce_partial_map: torch.Tensor | None = None,
    ) -> None:
        """MLA decode forward via the ``rocm_aiter_mla_decode_fwd`` custom op.

        ``q``, the reshaped ``kv_buffer``, ``o``, ``qo_indptr``,
        ``max_seqlen_qo``, ``kv_indptr``, ``kv_indices`` and
        ``kv_last_page_lens`` are passed positionally. The scale, quantization
        and work/reduce scheduling arguments are passed by keyword. The op's
        result is discarded; ``o`` is presumably the output buffer written
        in place (confirm against the op registration).

        Args:
            q: Query tensor; its last dimension also sizes the KV view.
            kv_buffer: KV cache storage, reinterpreted with ``view`` below.
            o: Output tensor forwarded to the op.
            sm_scale: Softmax scale.
            qo_indptr: Query offsets forwarded to the op.
            max_seqlen_qo: Maximum query sequence length.
            kv_indptr, kv_indices, kv_last_page_lens: Optional paged-KV
                metadata forwarded unchanged.
            logit_cap: Logit cap forwarded unchanged (0.0 by default).
            q_scale, kv_scale: Optional quantization scales.
            work_meta_data, work_indptr, work_info_set, reduce_indptr,
            reduce_final_map, reduce_partial_map: Optional scheduling
                metadata forwarded unchanged.
        """
        torch.ops.vllm.rocm_aiter_mla_decode_fwd(
            q,
            # Reinterpret the KV buffer as a 4-D view whose last dim equals
            # q's last dim (presumably the layout the op expects).
            kv_buffer.view(-1, 1, 1, q.shape[-1]),
            o,
            qo_indptr,
            max_seqlen_qo,
            kv_indptr,
            kv_indices,
            kv_last_page_lens,
            sm_scale=sm_scale,
            logit_cap=logit_cap,
            q_scale=q_scale,
            kv_scale=kv_scale,
            work_meta_data=work_meta_data,
            work_indptr=work_indptr,
            work_info_set=work_info_set,
            reduce_indptr=reduce_indptr,
            reduce_final_map=reduce_final_map,
            reduce_partial_map=reduce_partial_map,
        )

vllmellm's avatar
vllmellm committed
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
    @staticmethod
    def per_tensor_quant(
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward ``x``, ``quant_dtype`` and ``scale`` to the
        ``rocm_aiter_per_tensor_quant`` custom op.

        ``scale`` is passed through unchanged, including ``None``. The op's
        result (annotated as a pair of tensors) is returned as-is.
        """
        quant_op = torch.ops.vllm.rocm_aiter_per_tensor_quant
        quantized = quant_op(x, quant_dtype, scale)
        return quantized

    @staticmethod
    def per_token_quant(
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward ``x``, ``quant_dtype`` and ``scale`` to the
        ``rocm_aiter_per_token_quant`` custom op.

        ``scale`` is passed through unchanged, including ``None``. The op's
        result (annotated as a pair of tensors) is returned as-is.
        """
        quant_op = torch.ops.vllm.rocm_aiter_per_token_quant
        quantized = quant_op(x, quant_dtype, scale)
        return quantized

1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
    @staticmethod
    def gemm_a8wfp4(
        x: torch.Tensor,
        w: torch.Tensor,
        x_scales: torch.Tensor,
        w_scales: torch.Tensor,
        out_dtype: torch.dtype,
    ) -> torch.Tensor:
        """Run the ``rocm_aiter_gemm_a8wfp4`` custom op.

        Operands and their scales are forwarded positionally together with
        the requested output dtype; the op's tensor result is returned.
        """
        gemm_op = torch.ops.vllm.rocm_aiter_gemm_a8wfp4
        product = gemm_op(x, w, x_scales, w_scales, out_dtype)
        return product

1832
    @staticmethod
    def triton_fp4_gemm_dynamic_quant(
        x: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        out_dtype: torch.dtype | None = torch.bfloat16,
        x_scales: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """GEMM through AITER's Triton ``gemm_afp4wfp4`` kernel.

        When ``x_scales`` is None, ``x`` is first quantized with AITER's
        ``dynamic_mxfp4_quant``. Otherwise ``x`` is used as-is and
        ``x_scales`` are treated as its precomputed scales.

        Args:
            x: Activations (unquantized when ``x_scales`` is None, otherwise
                already quantized).
            weight: Quantized weight; ``weight.shape[0]`` sets the output's
                second dimension.
            weight_scale: Weight scales; handed to the kernel transposed.
            out_dtype: dtype of the allocated output (bfloat16 by default).
            x_scales: Optional precomputed activation scales.

        Returns:
            A freshly allocated tensor of shape
            ``(x_q.shape[0], weight.shape[0])`` and dtype ``out_dtype``.
        """
        from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
        from aiter.ops.triton.quant import dynamic_mxfp4_quant

        if x_scales is None:
            x_q, x_s = dynamic_mxfp4_quant(x)
        else:
            x_q, x_s = x, x_scales

        y = torch.empty(
            x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype
        )

        # y is passed as the kernel's output argument and returned afterwards,
        # so the kernel is relied upon to populate it.
        gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
        return y

1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
    @staticmethod
    def triton_rope_and_cache(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        positions: torch.Tensor,
        cos_sin_cache: torch.Tensor,
        is_neox: bool,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        layer_slot_mapping: torch.Tensor,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        flash_layout: bool,
        apply_scale: bool,
    ):
        """Apply rotary embedding and write K/V into the cache in one AITER call.

        ``cos_sin_cache`` is split in two along its last dimension into cos
        and sin halves. AITER's ``fused_qk_rope_reshape_and_cache`` is invoked
        with ``q_out=query`` and ``k_out=key``, so the rotated query/key are
        presumably written back into the input tensors. Nothing is returned.
        """
        from aiter.ops.triton.fused_kv_cache import (
            fused_qk_rope_reshape_and_cache as _fused_rope_and_cache,
        )

        cos_half, sin_half = torch.chunk(cos_sin_cache, 2, dim=-1)
        _fused_rope_and_cache(
            query, key, value,
            key_cache, value_cache, layer_slot_mapping,
            positions, cos_half, sin_half,
            k_scale, v_scale, is_neox,
            flash_layout=flash_layout,
            apply_scale=apply_scale,
            q_out=query,
            k_out=key,
            output_zeros=False,
        )

1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
    @staticmethod
    def batched_gemm_a16wfp4(
        X: torch.Tensor,
        W: torch.Tensor,
        w_scale: torch.Tensor,
        Y: torch.Tensor,
        transpose_bm: bool | None = False,
        prequant: bool | None = False,
        y_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Call AITER's Triton ``batched_gemm_a16wfp4`` kernel.

        ``X``, ``W`` and ``w_scale`` go in positionally. ``Y`` is passed as
        the kernel's ``y=`` argument (presumably a preallocated output). The
        remaining options are forwarded by keyword, and the kernel's result
        is returned.
        """
        # ruff: noqa: E501 # isort: skip
        from aiter.ops.triton.batched_gemm_a16wfp4 import (
            batched_gemm_a16wfp4 as _aiter_batched_gemm,
        )

        result = _aiter_batched_gemm(
            X,
            W,
            w_scale,
            y_scale=y_scale,
            prequant=prequant,
            transpose_bm=transpose_bm,
            y=Y,
        )
        return result

1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
    @staticmethod
    def triton_fp8_bmm(
        X: torch.Tensor,
        WQ: torch.Tensor,
        w_scale: torch.Tensor,
        group_size: int = 128,
        bias: torch.Tensor | None = None,
        dtype: torch.dtype | None = torch.bfloat16,
        splitK: int | None = None,
        YQ: torch.Tensor | None = None,
        transpose_bm: bool | None = False,
        config: dict | None = None,
    ) -> torch.Tensor:
        """Batched FP8 GEMM through AITER's Triton kernel
        ``batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant``.

        ``X``, ``WQ`` and ``w_scale`` are passed positionally. Every other
        argument is forwarded by keyword under the same name, and the
        kernel's result is returned.
        """
        # ruff: noqa: E501 # isort: skip
        from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (
            batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm,
        )

        out = aiter_triton_fp8_bmm(
            X,
            WQ,
            w_scale,
            group_size=group_size,
            dtype=dtype,
            bias=bias,
            config=config,
            splitK=splitK,
            transpose_bm=transpose_bm,
            YQ=YQ,
        )
        return out

    @staticmethod
    def group_fp8_quant(
        input_2d: torch.Tensor,
        group_size: int = 128,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Group-wise FP8 quantization via the ``rocm_aiter_group_fp8_quant`` op.

        Args:
            input_2d: Tensor to quantize (presumably 2-D, as the name
                suggests — not checked here).
            group_size: Quantization group size; only 128 is accepted.

        Returns:
            The op's result, annotated as a pair of tensors.

        Raises:
            AssertionError: If ``group_size`` is not 128 (only while
                assertions are enabled).
        """
        assert group_size == 128, "Group size must be 128"
        return torch.ops.vllm.rocm_aiter_group_fp8_quant(input_2d, group_size)
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972

    @staticmethod
    def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
        return (n, k) in [
            (1024, 8192),
            (2112, 7168),
            (3072, 1536),
            (32768, 8192),
            (4096, 7168),
            (4608, 7168),
            (512, 7168),
            (7168, 2048),
            (7168, 256),
            (8192, 1024),
            (8192, 32768),
        ]

1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
    @staticmethod
    def is_triton_gemm_afp4wfp4_presh_ws_tuned(n: int, k: int) -> bool:
        return (n, k) in [
            (8192, 4096),
            (1280, 8192),
            (16384, 53248),
            (106496, 16384),
            (57344, 8192),
            (8192, 2048),
            (2560, 8192),
            (10240, 8192),
            (16384, 16384),
            (8192, 28672),
            (28672, 8192),
            (18432, 16384),
            (8192, 1024),
            (7168, 8192),
            (5120, 8192),
            (8192, 8192),
            (8192, 7168),
            (14336, 8192),
            (8192, 14336),
            (8192, 3584),
        ]

1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
    @staticmethod
    def is_shuffled_per_token_w8a8_gemm_tuned(
        N: int, K: int, q_dtype_w: torch.dtype
    ) -> bool:
        """Look up ``(N, K, q_dtype_w)`` in AITER's pre-shuffled A8W8 tuning CSV.

        The CSV path is read from
        ``AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE``; the actual
        check is delegated to ``_check_kernel_tuned``.
        """
        import aiter.ops.gemm_op_a8w8 as a8w8_ops

        configs = a8w8_ops.AITER_CONFIGS
        return _check_kernel_tuned(
            N, K, q_dtype_w, configs.AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE
        )

    @staticmethod
    def is_per_token_w8a8_gemm_tuned(N: int, K: int, q_dtype_w: torch.dtype) -> bool:
        """Look up ``(N, K, q_dtype_w)`` in AITER's A8W8 GEMM tuning CSV.

        The CSV path is read from
        ``AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_FILE``; the actual check is
        delegated to ``_check_kernel_tuned``.
        """
        import aiter.ops.gemm_op_a8w8 as a8w8_ops

        configs = a8w8_ops.AITER_CONFIGS
        return _check_kernel_tuned(
            N, K, q_dtype_w, configs.AITER_CONFIG_GEMM_A8W8_FILE
        )

2016
2017
    @staticmethod
    def shuffle_weight(
        tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
    ) -> torch.Tensor:
        """Shuffle one weight tensor with AITER's ``shuffle_weight``.

        AITER is imported lazily on each call.

        Args:
            tensor: Weight tensor to shuffle.
            layout: Pair of block sizes forwarded to AITER; (16, 16) by
                default.

        Returns:
            The shuffled tensor produced by AITER.
        """
        from aiter.ops.shuffle import shuffle_weight

        return shuffle_weight(tensor, layout=layout)

2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
    @staticmethod
    def shuffle_weight_a16w4(
        tensor: torch.Tensor,
        nLane: int,
        gate_up: bool,
    ) -> torch.Tensor:
        """
        Shuffles the weight tensor into (A16W4) layout for AITER kernels.

        Args:
            tensor: The input weight tensor to be shuffled.
            nLane: Lane count forwarded unchanged to AITER's
                ``shuffle_weight_a16w4``. NOTE(review): its exact meaning
                is defined by AITER; confirm there.
            gate_up: Flag forwarded unchanged to AITER; presumably True for
                the fused gate/up (w13) weight and False for w2. Confirm.

        Returns:
            torch.Tensor: The shuffled tensor.
        """
        from aiter.ops.shuffle import shuffle_weight_a16w4

        return shuffle_weight_a16w4(tensor, nLane, gate_up)

    @staticmethod
    def shuffle_scale_a16w4(
        tensor: "torch.Tensor",
        num_experts: int,
        gate_up: bool,
    ) -> "torch.Tensor":
        """
        Rearrange a scale tensor into the A16W4 layout used by AITER kernels.

        Args:
            tensor: Scale tensor to rearrange.
            num_experts: Expert count, used by AITER's reshaping logic.
            gate_up: True for the w13 scales, False for the w2 scales.

        Returns:
            torch.Tensor: The rearranged scale tensor.
        """
        from aiter.ops.shuffle import shuffle_scale_a16w4 as _aiter_shuffle_scale

        shuffled = _aiter_shuffle_scale(tensor, num_experts, gate_up)
        return shuffled

2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
    @staticmethod
    def shuffle_weights(
        *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
    ) -> tuple[torch.Tensor, ...]:
        """
        Shuffle each given tensor with AITER's ``shuffle_weight``.

        Every tensor is rearranged into the block layout described by
        ``layout`` for AITER's optimized kernels.

        Args:
            *tensors: Any number of torch.Tensor objects to shuffle.
            layout: Pair of block sizes used while shuffling; (16, 16) by
                default.

        Returns:
            A tuple of shuffled tensors, in the same order as the inputs.
        """
        from aiter.ops.shuffle import shuffle_weight as _aiter_shuffle

        shuffled = [_aiter_shuffle(t, layout=layout) for t in tensors]
        return tuple(shuffled)

2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
    @staticmethod
    def flash_attn_varlen_func(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        min_seqlen_q: int | None = None,
        dropout_p: float = 0.0,
        softmax_scale: float | None = None,
        causal: bool = False,
        window_size: tuple[int, int] | None = None,
        alibi_slopes: torch.Tensor | None = None,
        return_lse: bool = False,
        out: torch.Tensor | None = None,
    ):
        """
        Variable-length flash attention backed by ``aiter.flash_attn_varlen_func``.

        Intentionally left undecorated by @is_aiter_supported, so an explicit
        backend choice made via attention_config still works when
        VLLM_ROCM_USE_AITER=0.

        The aiter function is imported lazily; every argument is forwarded by
        keyword under the same name and its result is returned.
        """
        from aiter import flash_attn_varlen_func as _aiter_varlen_attn

        result = _aiter_varlen_attn(
            # Query/key/value inputs.
            q=q,
            k=k,
            v=v,
            # Sequence-length metadata.
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            min_seqlen_q=min_seqlen_q,
            # Attention options.
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            # Outputs.
            return_lse=return_lse,
            out=out,
        )
        return result

    @staticmethod
    def pa_fwd_asm(
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_tables_stride0: int,
        K_QScale: torch.Tensor,
        V_QScale: torch.Tensor,
        out_: torch.Tensor,
    ):
        """
        Paged-attention forward pass backed by aiter's assembly kernel.

        Intentionally left undecorated by @is_aiter_supported, so an explicit
        backend choice made via attention_config still works when
        VLLM_ROCM_USE_AITER=0.

        ``aiter.pa_fwd_asm`` is imported lazily; every argument is forwarded
        by keyword under the same name and its result is returned.
        """
        from aiter import pa_fwd_asm as _aiter_pa_asm

        result = _aiter_pa_asm(
            # Query/key/value inputs.
            Q=Q,
            K=K,
            V=V,
            # Paging metadata.
            block_tables=block_tables,
            context_lens=context_lens,
            block_tables_stride0=block_tables_stride0,
            # Quantization scales.
            K_QScale=K_QScale,
            V_QScale=V_QScale,
            # Output buffer.
            out_=out_,
        )
        return result

2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
    @staticmethod
    def paged_attention_common(
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        tmp_out: torch.Tensor,
        max_logits: torch.Tensor,
        exp_sums: torch.Tensor,
        max_seq_len: int,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_tables_stride0: int,
        scale: float,
        K_QScale_hip: torch.Tensor,
        V_QScale_hip: torch.Tensor,
        K_QScale_asm: torch.Tensor,
        V_QScale_asm: torch.Tensor,
        out_: torch.Tensor,
        kv_cache_dtype: str,
    ):
        """
        Shared paged-attention entry point backed by ``aiter.paged_attention_common``.

        Intentionally left undecorated by @is_aiter_supported, so an explicit
        backend choice made via attention_config still works when
        VLLM_ROCM_USE_AITER=0.

        The aiter function is imported lazily; every argument is forwarded by
        keyword under the same name and its result is returned.
        """
        from aiter import paged_attention_common as _aiter_paged_attn

        result = _aiter_paged_attn(
            # Query/key/value inputs.
            Q=Q,
            K=K,
            V=V,
            # Scratch buffers.
            tmp_out=tmp_out,
            max_logits=max_logits,
            exp_sums=exp_sums,
            # Paging metadata.
            max_seq_len=max_seq_len,
            block_tables=block_tables,
            context_lens=context_lens,
            block_tables_stride0=block_tables_stride0,
            # Scaling / quantization.
            scale=scale,
            K_QScale_hip=K_QScale_hip,
            V_QScale_hip=V_QScale_hip,
            K_QScale_asm=K_QScale_asm,
            V_QScale_asm=V_QScale_asm,
            # Output and cache dtype.
            out_=out_,
            kv_cache_dtype=kv_cache_dtype,
        )
        return result

2221

2222
# Module-level call: runs when this module is imported. Presumably it registers
# the torch.ops.vllm.rocm_aiter_* custom ops that the rocm_aiter_ops wrappers
# dispatch to (confirm in register_ops_once).
rocm_aiter_ops.register_ops_once()