"docs/vscode:/vscode.git/clone" did not exist on "5df75c33cb543cb55e7a1616bfa2a4a3416243b8"
Unverified Commit 6e994b12 authored by meinie's avatar meinie Committed by GitHub
Browse files

[Bugfix] Assign Target for jit kernel (#648)



* fix: Copy Target to self.target

* refactor: Remove unused target attribute and adjust context management in JITKernel

- Removed the unused `target` attribute from the `JITKernel` class.
- Updated the context management in the `compile` method to utilize `self.target`, improving clarity and ensuring proper resource handling during compilation.

---------
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
parent fec9b930
...@@ -75,7 +75,6 @@ class JITKernel(object): ...@@ -75,7 +75,6 @@ class JITKernel(object):
""" """
self.prim_func = func self.prim_func = func
self.execution_backend = execution_backend self.execution_backend = execution_backend
self.target = target
self.target_host = target_host self.target_host = target_host
self.verbose = verbose self.verbose = verbose
...@@ -89,7 +88,7 @@ class JITKernel(object): ...@@ -89,7 +88,7 @@ class JITKernel(object):
target = determine_target(target) target = determine_target(target)
# Ensure the target is always a TVM Target object. # Ensure the target is always a TVM Target object.
target = Target(target) self.target = Target(target)
# Validate the execution backend. # Validate the execution backend.
assert execution_backend in [ assert execution_backend in [
...@@ -196,7 +195,7 @@ class JITKernel(object): ...@@ -196,7 +195,7 @@ class JITKernel(object):
# Compile the function with TVM, optimizing with shared memory lowering. # Compile the function with TVM, optimizing with shared memory lowering.
enable_host_codegen = execution_backend == "dlpack" enable_host_codegen = execution_backend == "dlpack"
enable_device_compile = execution_backend == "dlpack" enable_device_compile = execution_backend == "dlpack"
with tvm.transform.PassContext(opt_level=3, config=pass_configs): with tvm.transform.PassContext(opt_level=3, config=pass_configs), self.target:
artifact = tilelang.lower( artifact = tilelang.lower(
tilelang_func, tilelang_func,
target=target, target=target,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment