"examples/vscode:/vscode.git/clone" did not exist on "bc3c73ad0b75ee550fdcce6e124d5a222834d6ed"
Unverified commit 7ffc5b44, authored by Lei Wang and committed by GitHub
Browse files

[Cache] Introduce detailed target information for the disk kernel cache (#780)

* Fix type hint for target_host parameter in compile function to allow None value

* Refactor target handling in compile function to utilize determine_target for improved clarity and consistency

* Update the PrintConst function in codegen_cuda.cc to emit floating-point constants in hexfloat format (for bfloat16, float8/float4, and the generic float path), and append the same value in scientific notation as a comment for readability. This change improves the representation of floating-point constants in the generated code.

* Refactor the PrintType function in codegen_cuda.cc to remove the unnecessary failure conditions for FP6 and FP4 types with lane counts greater than 4. This change simplifies the logic and improves code clarity.

* Enhance benchmark_matmul.py to conditionally print Reference TFlops only if ref_latency is not None. Update param.py to ensure target is converted to string for consistency. Refactor tuner.py to utilize determine_target for improved clarity in target handling.

* Remove automatic commit and push step from AMD and NVIDIA CI workflows to streamline the process and avoid unnecessary commits.
parent cdc5d8d3
...@@ -61,11 +61,6 @@ jobs: ...@@ -61,11 +61,6 @@ jobs:
fi fi
rm -rf build rm -rf build
- name: Commit and Push Changes
uses: stefanzweifel/git-auto-commit-action@v5
with:
commit_message: "lint"
build-test-amd: build-test-amd:
runs-on: [self-hosted, amd, gpu] runs-on: [self-hosted, amd, gpu]
needs: format-check needs: format-check
......
...@@ -61,11 +61,6 @@ jobs: ...@@ -61,11 +61,6 @@ jobs:
fi fi
rm -rf build rm -rf build
- name: Commit and Push Changes
uses: stefanzweifel/git-auto-commit-action@v5
with:
commit_message: "lint"
build-test-nvidia: build-test-nvidia:
runs-on: [self-hosted, nvidia] runs-on: [self-hosted, nvidia]
needs: format-check needs: format-check
......
...@@ -243,4 +243,5 @@ if __name__ == "__main__": ...@@ -243,4 +243,5 @@ if __name__ == "__main__":
print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}") print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
print(f"Best config: {best_config}") print(f"Best config: {best_config}")
if ref_latency is not None:
print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}") print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}")
...@@ -325,16 +325,12 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*) ...@@ -325,16 +325,12 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
enable_fp6_ = true; enable_fp6_ = true;
if (t.lanes() <= 4) { if (t.lanes() <= 4) {
os << GetFP6Type(t); os << GetFP6Type(t);
} else {
fail = true;
} }
return; return;
} else if (t.is_float4()) { } else if (t.is_float4()) {
enable_fp4_ = true; enable_fp4_ = true;
if (t.lanes() <= 4) { if (t.lanes() <= 4) {
os << GetFP4Type(t); os << GetFP4Type(t);
} else {
fail = true;
} }
return; return;
} else if (t == DataType::Bool()) { } else if (t == DataType::Bool()) {
...@@ -1960,13 +1956,17 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os, ...@@ -1960,13 +1956,17 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
// Type code is kBFloat // Type code is kBFloat
if (op->dtype.is_bfloat16()) { if (op->dtype.is_bfloat16()) {
os << "bfloat16_t"; os << "bfloat16_t";
os << '(' << std::scientific << op->value << 'f' << ')'; os << '(' << std::hexfloat << op->value << 'f';
os << "/*" << std::scientific << op->value << "*/";
os << ')';
return; return;
} }
// Type code is kFloat8_e5m2 or kE4M4Float // Type code is kFloat8_e5m2 or kE4M4Float
if (op->dtype.is_float8() || op->dtype.is_float4()) { if (op->dtype.is_float8() || op->dtype.is_float4()) {
p->PrintType(op->dtype, os); p->PrintType(op->dtype, os);
os << '(' << std::scientific << op->value << 'f' << ')'; os << '(' << std::hexfloat << op->value << 'f';
os << "/*" << std::scientific << op->value << "*/";
os << ')';
return; return;
} }
// Type code is kFloat // Type code is kFloat
...@@ -1984,9 +1984,10 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os, ...@@ -1984,9 +1984,10 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN"); temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
p->need_math_constants_h_ = true; p->need_math_constants_h_ = true;
} else { } else {
temp << std::scientific << op->value; temp << std::hexfloat << op->value;
if (op->dtype.bits() == 32) if (op->dtype.bits() == 32)
temp << 'f'; temp << 'f';
temp << "/*" << std::scientific << op->value << "*/";
} }
p->MarkConst(temp.str()); p->MarkConst(temp.str());
os << temp.str(); os << temp.str();
......
...@@ -68,7 +68,7 @@ class CompileArgs: ...@@ -68,7 +68,7 @@ class CompileArgs:
"execution_backend": "execution_backend":
self.execution_backend, self.execution_backend,
"target": "target":
self.target, str(self.target),
"target_host": "target_host":
str(self.target_host) if self.target_host else None, str(self.target_host) if self.target_host else None,
"verbose": "verbose":
......
...@@ -28,6 +28,7 @@ from pathlib import Path ...@@ -28,6 +28,7 @@ from pathlib import Path
from tilelang import env from tilelang import env
from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
from tilelang.autotuner.capture import get_autotune_inputs from tilelang.autotuner.capture import get_autotune_inputs
from tilelang.utils.target import determine_target
from tilelang.jit.param import _P, _RProg from tilelang.jit.param import _P, _RProg
from tilelang.version import __version__ from tilelang.version import __version__
...@@ -150,7 +151,7 @@ class AutoTuner: ...@@ -150,7 +151,7 @@ class AutoTuner:
""" """
self.compile_args = CompileArgs( self.compile_args = CompileArgs(
out_idx=out_idx, out_idx=out_idx,
target=target, target=Target(determine_target(target)),
execution_backend=execution_backend, execution_backend=execution_backend,
target_host=target_host, target_host=target_host,
verbose=verbose, verbose=verbose,
......
...@@ -20,6 +20,7 @@ from tvm.tir import PrimFunc ...@@ -20,6 +20,7 @@ from tvm.tir import PrimFunc
from tvm.target import Target from tvm.target import Target
from tilelang.jit.kernel import JITKernel from tilelang.jit.kernel import JITKernel
from tilelang.utils.target import determine_target
from tilelang.cache import cached from tilelang.cache import cached
from os import path, makedirs from os import path, makedirs
from logging import getLogger from logging import getLogger
...@@ -34,7 +35,7 @@ def compile( ...@@ -34,7 +35,7 @@ def compile(
out_idx: Union[List[int], int, None] = None, out_idx: Union[List[int], int, None] = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
target: Union[str, Target] = "auto", target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None, target_host: Union[str, Target, None] = None,
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[Union[List[str], str]] = None, compile_flags: Optional[Union[List[str], str]] = None,
...@@ -69,6 +70,10 @@ def compile( ...@@ -69,6 +70,10 @@ def compile(
assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}" assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}"
if isinstance(compile_flags, str): if isinstance(compile_flags, str):
compile_flags = [compile_flags] compile_flags = [compile_flags]
# This path is not a performance critical path, so we can afford to convert the target.
target = Target(determine_target(target))
return cached( return cached(
func=func, func=func,
out_idx=out_idx, out_idx=out_idx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment