"lm_eval/tasks/logiqa2/logiqa2.yaml" did not exist on "4e44f0aac16c56adb127620da6b961c15e2a0ed0"
triton-lang_triton_6570_lazy_init.patch 2.64 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
From 7240b92457a723a3a3ec2292e40df6274382524c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?oliver=20k=C3=B6nig?= <okoenig@nvidia.com>
Date: Wed, 11 Jun 2025 19:13:30 +0000
Subject: [PATCH] f
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: oliver könig <okoenig@nvidia.com>
---
 external/patches/main.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py b/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py
index 69305dc94..4600542b8 100644
--- a/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py
+++ b/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py
@@ -4,6 +4,7 @@ import builtins
 import os
 import time
 import inspect
+from functools import cached_property
 from typing import Dict, Tuple, List, Optional
 
 from .jit import KernelInterface
@@ -94,6 +95,7 @@ class Autotuner(KernelInterface):
         while not inspect.isfunction(self.base_fn):
             self.base_fn = self.base_fn.fn
 
+        self._do_bench = do_bench
         self.num_warmups = warmup
         self.num_reps = rep
         self.use_cuda_graph = use_cuda_graph
@@ -107,7 +109,7 @@ class Autotuner(KernelInterface):
                           stacklevel=1)
             if use_cuda_graph:
                 from ..testing import do_bench_cudagraph
-                self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
+                self._do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
                     kernel_call,
                     rep=rep if rep is not None else 100,
                     quantiles=quantiles,
@@ -115,7 +117,7 @@ class Autotuner(KernelInterface):
                 return
 
             import triton.testing
-            self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
+            self._do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
                 kernel_call,
                 warmup=warmup if warmup is not None else 25,
                 rep=rep if rep is not None else 100,
@@ -123,10 +125,11 @@ class Autotuner(KernelInterface):
             )
             return
 
-        if do_bench is None:
-            self.do_bench = driver.active.get_benchmarker()
-        else:
-            self.do_bench = do_bench
+    @cached_property
+    def do_bench(self):
+        if self._do_bench is None:
+            return driver.active.get_benchmarker()
+        return self._do_bench
 
     def _bench(self, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure
-- 
2.43.0