import contextlib
import functools
from typing import List, Optional, Tuple, Union

import torch

from vllm._core_ext import ScalarType
from vllm.logger import init_logger

logger = init_logger(__name__)

# The compiled extension is imported only for its import-time effects
# (presumably registering the `torch.ops._C*` namespaces used below -- confirm
# against the extension build). A failed import is logged, not raised, so this
# module itself stays importable.
try:
    import vllm._C
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# The MoE extension is optional: an ImportError here is silently ignored.
with contextlib.suppress(ImportError):
    # ruff: noqa: F401
    import vllm._moe_C


def is_custom_op_supported(op_name: str) -> bool:
    """Return True when `torch._C._jit_get_operation` resolves `op_name`.

    The lookup yields an ``(op, overload_names)`` pair; only the first
    element is inspected.
    """
    resolved, _overload_names = torch._C._jit_get_operation(op_name)
    if resolved is None:
        return False
    return True

def hint_on_error(fn):
    """Decorate `fn` so that an ``AttributeError`` it raises is logged.

    The wrapper forwards all positional and keyword arguments to `fn` and
    returns its result unchanged. If `fn` raises ``AttributeError``, an error
    message naming `fn` and suggesting a clean rebuild is logged through the
    module-level `logger`, and the same exception is re-raised. Any other
    exception type propagates untouched.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except AttributeError as e:
            msg = (
                "Error in calling custom op %s: %s\n"
                "Possibly you have built or installed an obsolete version of vllm.\n"
                "Please try a clean build and install of vllm, "
                "or remove old built files such as vllm/*cpython*.so and build/ ."
            )
            logger.error(msg, fn.__name__, e)
            # Bare `raise` re-raises the active exception as-is.
            raise

    return wrapper


# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Forward ``(out, x)`` to the ``torch.ops._C.silu_and_mul`` custom op.

    Nothing is returned; presumably the kernel writes its result into
    ``out`` (confirm against the op schema).
    """
    torch.ops._C.silu_and_mul(out, x)


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Forward ``(out, x)`` to the ``torch.ops._C.gelu_and_mul`` custom op.

    Nothing is returned; presumably the kernel writes its result into
    ``out`` (confirm against the op schema).
    """
    torch.ops._C.gelu_and_mul(out, x)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Forward ``(out, x)`` to the ``torch.ops._C.gelu_tanh_and_mul`` op.

    Nothing is returned; presumably the kernel writes its result into
    ``out`` (confirm against the op schema).
    """
    torch.ops._C.gelu_tanh_and_mul(out, x)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    """Forward ``(out, x)`` to the ``torch.ops._C.gelu_fast`` custom op.

    Nothing is returned; presumably the kernel writes its result into
    ``out`` (confirm against the op schema).
    """
    torch.ops._C.gelu_fast(out, x)


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    """Forward ``(out, x)`` to the ``torch.ops._C.gelu_new`` custom op.

    Nothing is returned; presumably the kernel writes its result into
    ``out`` (confirm against the op schema).
    """
    torch.ops._C.gelu_new(out, x)


def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
    """Forward ``(out, x)`` to the ``torch.ops._C.gelu_quick`` custom op."""
    kernel = torch.ops._C.gelu_quick
    kernel(out, x)


# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Dispatch to the ``torch.ops._C.paged_attention_v1`` custom op.

    Every argument is forwarded positionally, in declaration order. Nothing
    is returned; presumably the kernel writes into ``out`` (confirm against
    the op schema).
    """
    torch.ops._C.paged_attention_v1(
        out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
        k_scale, v_scale, tp_rank, blocksparse_local_blocks,
        blocksparse_vert_stride, blocksparse_block_size,
        blocksparse_head_sliding_step)


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Dispatch to the ``torch.ops._C.paged_attention_v2`` custom op.

    Every argument is forwarded positionally, in declaration order. Nothing
    is returned; presumably the kernel writes into ``out`` and uses
    ``exp_sum`` / ``max_logits`` / ``tmp_out`` as scratch (confirm against
    the op schema).
    """
    torch.ops._C.paged_attention_v2(
        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
        alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step)


# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Dispatch to the ``torch.ops._C.rotary_embedding`` custom op.

    Arguments are forwarded positionally in declaration order. Nothing is
    returned; presumably ``query`` and ``key`` are updated in place (confirm
    against the op schema).
    """
    torch.ops._C.rotary_embedding(positions, query, key, head_size,
                                  cos_sin_cache, is_neox)


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    """Dispatch to the ``torch.ops._C.batched_rotary_embedding`` custom op.

    Arguments are forwarded positionally in declaration order. Nothing is
    returned; presumably ``query`` and ``key`` are updated in place (confirm
    against the op schema).
    """
    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
                                          cos_sin_cache, is_neox, rot_dim,
                                          cos_sin_cache_offsets)


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    """Forward ``(out, input, weight, epsilon)`` to ``torch.ops._C.rms_norm``.

    Nothing is returned; presumably the kernel writes into ``out`` (confirm
    against the op schema).
    """
    torch.ops._C.rms_norm(out, input, weight, epsilon)


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    """Forward the arguments to ``torch.ops._C.fused_add_rms_norm``.

    Nothing is returned; presumably ``input`` and ``residual`` are updated
    in place (confirm against the op schema).
    """
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)


def advance_step(num_seqs: int, num_queries: int, block_size: int,
                 input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
                 input_positions: torch.Tensor, seq_lens: torch.Tensor,
                 slot_mapping: torch.Tensor,
                 block_tables: torch.Tensor) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner"""
    step_args = (num_seqs, num_queries, block_size, input_tokens,
                 sampled_token_ids, input_positions, seq_lens, slot_mapping,
                 block_tables)
    # The op's return value is passed through unchanged.
    return torch.ops._C.advance_step(*step_args)


# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    """Return the result of the ``torch.ops._C.awq_dequantize`` custom op.

    Arguments are forwarded positionally in declaration order.
    """
    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
                                       thx, thy)


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    """Return the result of the ``torch.ops._C.awq_gemm`` custom op.

    Arguments are forwarded positionally in declaration order.
    """
    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    """Return the result of the ``torch.ops._C.gptq_gemm`` custom op.

    Arguments are forwarded positionally in declaration order.
    """
    return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                                  b_g_idx, use_exllama, bit)


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    """Forward ``(q_weight, q_perm, bit)`` to ``torch.ops._C.gptq_shuffle``.

    Nothing is returned; presumably ``q_weight`` is modified in place
    (confirm against the op schema).
    """
    torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)


# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
                    lookup_table: torch.Tensor) -> None:
    """Forward the arguments to the ``torch.ops._C.squeezellm_gemm`` op.

    Nothing is returned; which tensor receives the result is defined by the
    kernel (presumably ``mul`` -- confirm against the op schema).
    """
    torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    """Return the result of the ``torch.ops._C.marlin_gemm`` custom op.

    Arguments are forwarded positionally in declaration order.
    """
    return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
                                    size_n, size_k)


# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, b_q_type: ScalarType,
                        size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Return the result of the ``torch.ops._C.gptq_marlin_24_gemm`` op.

    Arguments are forwarded positionally in declaration order.
    """
    return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
                                            workspace, b_q_type, size_m,
                                            size_n, size_k)


# cutlass
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Return what ``torch.ops._C.cutlass_scaled_mm_supports_fp8`` reports
    for the given device capability value."""
    probe = torch.ops._C.cutlass_scaled_mm_supports_fp8
    return probe(cuda_device_capability)


def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the ``torch.ops._C.cutlass_scaled_mm`` op into a fresh tensor.

    Preconditions (checked with ``assert``):
      * both dimensions of ``b`` are multiples of 16;
      * ``out_dtype`` is ``torch.bfloat16`` or ``torch.float16``;
      * ``bias``, when given, has ``shape[0] == b.shape[1]`` and dtype
        ``out_dtype``.

    Returns:
        A new ``(a.shape[0], b.shape[1])`` tensor of ``out_dtype`` on
        ``a.device``, allocated here and passed to the op as its first
        argument.
    """
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or bias.shape[0] == b.shape[
        1] and bias.dtype == out_dtype

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

    return out


def cutlass_scaled_mm_azp(a: torch.Tensor,
                          b: torch.Tensor,
                          scale_a: torch.Tensor,
                          scale_b: torch.Tensor,
                          out_dtype: torch.dtype,
                          azp_adj: torch.Tensor,
                          azp: Optional[torch.Tensor] = None,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the ``torch.ops._C.cutlass_scaled_mm_azp`` op into a fresh tensor.

    Asserts that both dimensions of ``b`` are multiples of 16, that
    ``out_dtype`` is bfloat16 or float16, and that ``bias`` (when given) has
    ``b.shape[1]`` elements and dtype ``out_dtype``. The returned tensor has
    shape ``(a.shape[0], b.shape[1])`` and lives on ``a.device``.
    """
    rows_b, cols_b = b.shape[0], b.shape[1]
    assert rows_b % 16 == 0 and cols_b % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or (bias.numel() == cols_b
                            and bias.dtype == out_dtype)

    result = torch.empty((a.shape[0], cols_b),
                         dtype=out_dtype,
                         device=a.device)
    torch.ops._C.cutlass_scaled_mm_azp(result, a, b, scale_a, scale_b,
                                       azp_adj, azp, bias)
    return result


# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Return the result of the ``torch.ops._C.aqlm_gemm`` custom op.

    Arguments are forwarded positionally in declaration order; ``bias`` may
    be ``None``.
    """
    return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
                                  codebook_partition_sizes, bias)


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
    """Return the result of the ``torch.ops._C.aqlm_dequant`` custom op.

    Arguments are forwarded positionally in declaration order.
    """
    return torch.ops._C.aqlm_dequant(codes, codebooks,
                                     codebook_partition_sizes)


# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Return the result of the ``torch.ops._C.gptq_marlin_repack`` op.

    Arguments are forwarded positionally in declaration order.
    """
    return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
                                           num_bits)


# awq_marlin
def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
                      num_bits: int) -> torch.Tensor:
    """Return the result of the ``torch.ops._C.awq_marlin_repack`` op."""
    repack = torch.ops._C.awq_marlin_repack
    return repack(b_q_weight, size_k, size_n, num_bits)


def gptq_marlin_gemm(a: torch.Tensor,
                     b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor,
                     b_zeros: torch.Tensor,
                     g_idx: torch.Tensor,
                     perm: torch.Tensor,
                     workspace: torch.Tensor,
                     b_q_type: ScalarType,
                     size_m: int,
                     size_n: int,
                     size_k: int,
                     is_k_full: bool,
                     has_zp: bool = False,
                     use_fp32_reduce: bool = False) -> torch.Tensor:
    """Return the result of the ``torch.ops._C.gptq_marlin_gemm`` op.

    Arguments are forwarded positionally in declaration order; ``has_zp``
    and ``use_fp32_reduce`` default to ``False``.
    """
    return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
                                         g_idx, perm, workspace, b_q_type,
                                         size_m, size_n, size_k, is_k_full,
                                         has_zp, use_fp32_reduce)


# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_scales: torch.Tensor, workspace: torch.Tensor,
                    num_bits: int, size_m: int, size_n: int,
                    size_k: int) -> torch.Tensor:
    """Return the result of the ``torch.ops._C.fp8_marlin_gemm`` op."""
    gemm_args = (a, b_q_weight, b_scales, workspace, num_bits, size_m,
                 size_n, size_k)
    return torch.ops._C.fp8_marlin_gemm(*gemm_args)


# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8. Must be 2-D
            (checked with ``assert``).
        scale: Optional scaling factor for the FP8 quantization
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
    shape: Union[Tuple[int, int], torch.Size] = input.shape
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)

    if scale is None:
        if use_per_token_if_dynamic:
            # One scale per (possibly padded) output row.
            scale = torch.empty((shape[0], 1),
                                device=input.device,
                                dtype=torch.float32)
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
                output, input, scale, scale_ub)
        else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # num_token_padding not implemented for this case
        assert (scale.numel() == 1 or num_token_padding is None)
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)

    return output, scale


# int8
def scaled_int8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to int8 and return the quantized tensor and scale.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.

    Returns:
      Tuple[torch.Tensor, torch.Tensor] : Output int8 tensor and scales.
          In the static case the scale passed in is returned as-is; in the
          dynamic case a new float32 tensor with one row per token
          (``input.numel() // input.shape[-1]`` rows, one column) is returned.
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # static-per-tensor quantization.
        torch.ops._C.static_scaled_int8_quant(output, input, scale)
        return output, scale

    # dynamic-per-token quantization.
    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                               device=input.device,
                               dtype=torch.float32)
    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales)
    return output, input_scales


# qqq ops
def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    s_tok: torch.Tensor, s_ch: torch.Tensor,
                    s_group: torch.Tensor, workspace: torch.Tensor,
                    size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Return the result of the ``torch.ops._C.marlin_qqq_gemm`` op."""
    gemm_args = (a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m,
                 size_n, size_k)
    return torch.ops._C.marlin_qqq_gemm(*gemm_args)


# gguf
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int):
    """Return the result of the ``torch.ops._C.ggml_dequantize`` op."""
    dequantize = torch.ops._C.ggml_dequantize
    return dequantize(W, quant_type, m, n)


def ggml_mul_mat_vec(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
):
    """Return the result of the ``torch.ops._C.ggml_mul_mat_vec`` op."""
    mul_mat_vec = torch.ops._C.ggml_mul_mat_vec
    return mul_mat_vec(W, X, quant_type, row)


def ggml_mul_mat_vec_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
):
    """Return the result of the ``torch.ops._C.ggml_mul_mat_vec_a8`` op."""
    mul_mat_vec_a8 = torch.ops._C.ggml_mul_mat_vec_a8
    return mul_mat_vec_a8(W, X, quant_type, row)


def ggml_mul_mat_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
):
    """Return the result of the ``torch.ops._C.ggml_mul_mat_a8`` op."""
    mul_mat_a8 = torch.ops._C.ggml_mul_mat_a8
    return mul_mat_a8(W, X, quant_type, row)


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    """Forward the arguments to ``torch.ops._C.moe_align_block_size``.

    Nothing is returned; presumably ``sorted_token_ids``, ``experts_ids``
    and ``num_tokens_post_pad`` are filled in by the kernel (confirm against
    the op schema).
    """
    torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
                                      sorted_token_ids, experts_ids,
                                      num_tokens_post_pad)


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: torch.Tensor) -> None:
    """Forward the arguments to the ``torch.ops._moe_C.topk_softmax`` op.

    Nothing is returned; presumably ``topk_weights``, ``topk_ids`` and
    ``token_expert_indicies`` are filled in from ``gating_output`` (confirm
    against the op schema). The misspelled parameter name
    ``token_expert_indicies`` is kept for caller compatibility.
    """
    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
                                  token_expert_indicies, gating_output)


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Forward the arguments to ``torch.ops._C_cache_ops.reshape_and_cache``.

    Nothing is returned; presumably ``key_cache`` and ``value_cache`` are
    updated in place (confirm against the op schema).
    """
    torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
                                             value_cache, slot_mapping,
                                             kv_cache_dtype, k_scale, v_scale)


def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Forward the arguments to
    ``torch.ops._C_cache_ops.reshape_and_cache_flash``.

    Nothing is returned; presumably ``key_cache`` and ``value_cache`` are
    updated in place (confirm against the op schema).
    """
    torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                   value_cache, slot_mapping,
                                                   kv_cache_dtype, k_scale,
                                                   v_scale)


def copy_blocks(key_caches: List[torch.Tensor],
                value_caches: List[torch.Tensor],
                block_mapping: torch.Tensor) -> None:
    """Forward the arguments to ``torch.ops._C_cache_ops.copy_blocks``.

    Nothing is returned; presumably the cache tensors are modified in place
    (confirm against the op schema).
    """
    torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    """Forward ``(src, dst, block_mapping)`` to
    ``torch.ops._C_cache_ops.swap_blocks``.

    Nothing is returned; presumably ``dst`` is written by the kernel
    (confirm against the op schema).
    """
    torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(output: torch.Tensor,
                input: torch.Tensor,
                scale: float = 1.0,
                kv_dtype: str = "fp8") -> None:
    """Forward the arguments to ``torch.ops._C_cache_ops.convert_fp8``.

    ``scale`` defaults to ``1.0`` and ``kv_dtype`` to ``"fp8"``. Nothing is
    returned; presumably the kernel writes into ``output`` (confirm against
    the op schema).
    """
    torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    """Return ``torch.ops._C_cuda_utils.get_device_attribute(attribute,
    device)``."""
    query = torch.ops._C_cuda_utils.get_device_attribute
    return query(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    """Return the value reported by the ``_C_cuda_utils`` op of the same name
    for ``device``."""
    # ruff: noqa: E501
    query = torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute
    return query(device)


# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                   handles: List[str], offsets: List[int], rank: int,
                   full_nvlink: bool) -> int:
    """Return the result of ``torch.ops._C_custom_ar.init_custom_ar``.

    Arguments are forwarded positionally in declaration order.
    """
    init_args = (meta, rank_data, handles, offsets, rank, full_nvlink)
    return torch.ops._C_custom_ar.init_custom_ar(*init_args)


def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
                     full_nvlink: bool) -> bool:
    """Return the result of ``torch.ops._C_custom_ar.should_custom_ar``."""
    decide = torch.ops._C_custom_ar.should_custom_ar
    return decide(inp, max_size, world_size, full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    """Forward ``(fa, inp, out)`` to ``torch.ops._C_custom_ar.all_reduce_reg``."""
    reduce_op = torch.ops._C_custom_ar.all_reduce_reg
    reduce_op(fa, inp, out)

def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
                     out: torch.Tensor) -> None:
    """Forward ``(fa, inp, reg_buffer, out)`` to
    ``torch.ops._C_custom_ar.all_reduce_unreg``."""
    reduce_op = torch.ops._C_custom_ar.all_reduce_unreg
    reduce_op(fa, inp, reg_buffer, out)

def dispose(fa: int) -> None:
    """Forward ``fa`` to ``torch.ops._C_custom_ar.dispose``."""
    release = torch.ops._C_custom_ar.dispose
    release(fa)


def meta_size() -> int:
    """Return the value reported by ``torch.ops._C_custom_ar.meta_size``."""
    size = torch.ops._C_custom_ar.meta_size()
    return size


def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
                    offsets: List[int]) -> None:
    """Forward the arguments to ``torch.ops._C_custom_ar.register_buffer``.

    Whatever the op returns is passed through unchanged.
    """
    register = torch.ops._C_custom_ar.register_buffer
    return register(fa, t, handles, offsets)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
    """Return the result of
    ``torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)``."""
    ipc_meta = torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
    return ipc_meta


def register_graph_buffers(fa: int, handles: List[str],
                           offsets: List[List[int]]) -> None:
    """Forward ``(fa, handles, offsets)`` to
    ``torch.ops._C_custom_ar.register_graph_buffers``."""
    register = torch.ops._C_custom_ar.register_graph_buffers
    register(fa, handles, offsets)


# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
#
# Rebind every function that is defined in this file and has torch.Tensor
# among its annotations to its `hint_on_error`-wrapped version.
# `ann == "torch.Tensor"` handles the case where annotations were turned
# into strings (e.g. via `from __future__ import annotations`).
_fn_type = type(lambda x: x)
# Iterate over a snapshot of globals() so the namespace cannot change size
# mid-iteration; the comprehension keeps its loop variables out of the
# module namespace.
_wrapped = {
    name: hint_on_error(obj)
    for name, obj in list(globals().items())
    if isinstance(obj, _fn_type)
    and obj.__code__.co_filename == __file__
    and any(ann is torch.Tensor or ann == "torch.Tensor"
            for ann in obj.__annotations__.values())
}
globals().update(_wrapped)
del _fn_type, _wrapped