"driver/vscode:/vscode.git/clone" did not exist on "1c4ef23cff46f627ea22c8e2afc68218017f2523"
Unverified commit 7a5077e4, authored by Lei Wang and committed by GitHub

[Transform] Migrate `LowerIntrin` from tvm into tilelang (#999)

* Do not lower ceildiv to >>

* lint fix

* test fix

* fallback ceildiv changes
parent eb37e459
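For context: `T.ceildiv(a, b)` is canonicalized to `floordiv(a + (b - 1), b)` before it reaches this pass, and the pass may lower a floordiv by a power of two to an arithmetic right shift, or rewrite it through truncating division where operand signs can be proven. A quick sanity check of those identities in plain Python (illustrative only, not part of the commit):

# Python's // is floor division, >> is an arithmetic shift, and
# int(a / b) truncates toward zero like C's integer division.
def ceildiv(a: int, b: int) -> int:
    return (a + b - 1) // b

for a in range(-256, 256):
    assert ceildiv(a, 32) == -(-a // 32)    # two ceildiv spellings agree
    assert (a + 31) // 32 == (a + 31) >> 5  # floordiv by 2^5 == shift, any sign
    if a + 31 >= 0:                         # truncation matches only when >= 0
        assert (a + 31) // 32 == int((a + 31) / 32)
assert int(-33 / 32) == -1 and -33 // 32 == -2  # they diverge below zero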
@@ -11,10 +11,10 @@ export PYTHONPATH=$ROOT_DIR:$PYTHONPATH
 # Run pytest in parallel (4 workers) for all tests in the examples directory
 cd examples
-python -m pytest -n 4 .
+python -m pytest -n 4 . --verbose --color=yes --durations=0 --showlocals --cache-clear
 cd ..
 # Run pytest in parallel (4 workers) for all tests in the testing/python directory
 cd testing/python
-python -m pytest -n 4 .
+python -m pytest -n 4 . --verbose --color=yes --durations=0 --showlocals --cache-clear
 cd ..
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Lower intrinsic calls and ops to device-specific IR when possible.
* \file lower_intrin.cc
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
#include <limits>
#include <unordered_set>
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/pattern_match.h"
namespace tvm {
namespace tl {
using namespace tir;
class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
public:
using IRMutatorWithAnalyzer::VisitExpr_;
using IRMutatorWithAnalyzer::VisitStmt_;
using FLowerGeneral = ffi::TypedFunction<PrimExpr(PrimExpr)>;
IntrinInjecter(arith::Analyzer *analyzer, std::string target,
std::string mtriple = "")
: IRMutatorWithAnalyzer(analyzer) {
std::vector<std::string> patterns;
patterns.push_back(target + ".FLowerIntrinsic");
patterns.push_back(target + ".FLegalize");
bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
if (is_llvm_aarch64) {
patterns.push_back(target + ".aarch64.FLowerIntrinsic");
patterns.push_back(target + ".aarch64.FLegalize");
}
patterns.push_back("default.FLowerIntrinsic");
patterns.push_back("default.FLegalize");
for (const std::string &pattern : patterns)
if (Op::HasAttrMap(pattern)) {
attr_maps_.push_back(Op::GetAttrMap<FLowerGeneral>(pattern));
if (fma_ == nullptr) {
fma_ = (*attr_maps_.rbegin()).get(Op::Get("tir.fma"), nullptr);
}
}
}
PrimExpr VisitExpr_(const CallNode *op) final {
if (auto *ptr_op = op->op.as<OpNode>()) {
for (const auto &f_attr_map : attr_maps_) {
FLowerGeneral f = f_attr_map.get(GetRef<Op>(ptr_op), nullptr);
if (f != nullptr) {
PrimExpr e = GetRef<PrimExpr>(op);
PrimExpr r = f(e);
ICHECK(r.defined()) << "intrinsic rule must always return valid Expr";
if (!r.same_as(e)) {
r = this->VisitExpr(r);
if (r.defined()) {
return r;
}
}
}
}
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
PrimExpr VisitExpr_(const AddNode *op) final {
if (const MulNode *mb = op->b.as<MulNode>()) {
return MakeFMA(mb->a, mb->b, op->a, op);
} else if (const MulNode *ma = op->a.as<MulNode>()) {
return MakeFMA(ma->a, ma->b, op->b, op);
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
// We use floordiv for integer analysis,
// but need to lower it to native truncdiv instructions
PrimExpr VisitExpr_(const FloorDivNode *op) final {
auto e = GetRef<PrimExpr>(op);
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorDivNode>();
if (op == nullptr)
return ret;
int shift;
const DataType &dtype = op->dtype;
ICHECK(dtype.is_int() || dtype.is_uint());
// e.g. lower (a + 31) // 32 to (a + 31) >> 5
if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) {
// lower to right shift if possible.
return op->a >> make_const(dtype, shift);
}
if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
// Common path, positive divisor
if (analyzer_->CanProveGreaterEqual(op->a, 0) ||
analyzer_->CanProveGreaterEqual(e, 0)) {
return truncdiv(op->a, op->b);
}
// If the numerator's lower bound is known, express the floordiv
// in terms of truncdiv using only positive operands.
arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
if (const_int_bound->min_value < 0 &&
const_int_bound->min_value >
-(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))
->value)) {
// The goal is to write floordiv(a,b) in terms of truncdiv, without
// using negative operands.
//
// For any integer c
//
// floordiv(a,b) == floordiv(a + b*c - b*c, b)
// == floordiv(a + b*c, b) - c
//
// Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of
// truncdiv as follows.
//
// c == ceildiv(-a_min,b)
// == floordiv(-a_min + (b-1), b)
// == truncdiv(-a_min + (b-1), b)
//
// When substituted into `a + b*c`, this results in a positive argument.
//
// a + b*c
// == a + b*ceildiv(-a_min,b)
// == a - b*floordiv(a_min,b)
// >= a - b*floordiv(a,b)
// == floormod(a, b)
// >= 0
//
// Since the argument is positive, this allows floordiv to be written as
// follows.
//
// floordiv(a,b)
// == floordiv(a + b*c, b) - c
// == truncdiv(a + b*c, b) - c
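// Worked example (illustrative): with a_min = -5 and b = 4,
//   c = truncdiv(5 + 3, 4) = 2, and
//   floordiv(-5, 4) == truncdiv(-5 + 4*2, 4) - 2 == truncdiv(3, 4) - 2 == -2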
IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
// Skip analyzer simplification so we preserve straightforward div
// expressions.
PrimExpr offset_numerator = op->a + op->b * ceildiv;
return truncdiv(offset_numerator, op->b) - ceildiv;
}
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
PrimExpr rdiv = truncdiv(op->a, op->b);
PrimExpr rmod = truncmod(op->a, op->b);
// condition on b >= 0.
// truncmod(a, b) < 0 will implies ceildiv,
// So we need to correct these cases.
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) &&
support_bitwise_op_) {
// equivalent to rdiv + (rmod >= 0 ? 0: -1);
return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
} else {
return tir::Select(rmod >= 0, rdiv, rdiv - make_const(dtype, 1));
}
} else {
if (dtype.is_float()) {
// floor(a / b)
return VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>());
} else {
// uncommon case
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
auto rmod = tir::Var("rmod", dtype);
auto rdiv = tir::Var("rdiv", dtype);
// b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
// b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
PrimExpr let_rdiv = tir::Let(
rdiv, truncdiv(op->a, op->b),
tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
rdiv, rdiv - make_const(dtype, 1)));
return Let(rmod, truncmod(op->a, op->b), let_rdiv);
}
}
}
PrimExpr VisitExpr_(const FloorModNode *op) final {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorModNode>();
if (op == nullptr)
return ret;
// Lower floormod to native truncmod.
int shift;
const DataType &dtype = op->dtype;
ICHECK(dtype.is_int() || dtype.is_uint());
if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) {
// lower to masking if possible.
int64_t mask =
(static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
return op->a & make_const(dtype, mask);
}
if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
// Common path, positive divisor
if (analyzer_->CanProveGreaterEqual(op->a, 0)) {
return truncmod(op->a, op->b);
}
// If the numerator's lower bound is known, express the floormod
// in terms of truncmod using only positive operands.
arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
if (const_int_bound->min_value < 0 &&
const_int_bound->min_value >
-(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))
->value)) {
// The goal is to write floormod(a,b) in terms of truncdiv and truncmod,
// without using negative operands.
//
// For any integer c
//
// floormod(a, b) == floormod(a + b*c, b)
//
// Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of
// truncdiv as follows.
//
// c == ceildiv(-a_min,b)
// == floordiv(-a_min + (b-1), b)
// == truncdiv(-a_min + (b-1), b)
//
// When substituted into `a + b*c`, this results in a positive argument.
//
// a + b*c
// == a + b*ceildiv(-a_min,b)
// == a - b*floordiv(a_min,b)
// >= a - b*floordiv(a,b)
// == floormod(a, b)
// >= 0
//
// Since the argument is positive, this allows floormod to be written as
// follows.
//
// floormod(a,b)
// == floormod(a + b*c, b)
// == truncmod(a + b*c, b)
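// Worked example (illustrative): with a_min = -5 and b = 4, c = 2, so
//   floormod(-5, 4) == truncmod(-5 + 4*2, 4) == truncmod(3, 4) == 3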
IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b);
PrimExpr offset_numerator =
analyzer_->Simplify(op->a + op->b * ceildiv);
return truncmod(offset_numerator, op->b);
}
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident";
// NOTE:condition on b >= 0.
// mod(a, b) < 0 will imply we are doing ceildiv,
// So we need to correct these cases.
PrimExpr rmod = truncmod(op->a, op->b);
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) &&
support_bitwise_op_) {
// (rmod >> shift) & b
// -> (rmod >= 0 ? 0: -1) & b
// -> rmod >= 0 ? 0 : b
return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1)));
} else {
return tir::Select(rmod >= 0, rmod, rmod + op->b);
}
} else {
if (dtype.is_float()) {
// a - floor(a / b) * b
return op->a -
(VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>()) * op->b);
} else {
// uncommon case
DLOG(INFO)
<< "LowerFloorMod: Cannot decide the sign of the divisor and dividend";
auto rmod = tir::Var("rmod", dtype);
// b > 0 && rmod >= 0 -> rmod
// b > 0 && rmod < 0 -> rmod + b
// b < 0 && rmod < 0 -> rmod
// b < 0 && rmod > 0 -> rmod + b
return Let(rmod, truncmod(op->a, op->b),
Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
rmod, rmod + op->b));
}
}
}
PrimExpr VisitExpr_(const MaxNode *op) final {
using namespace arith;
PVar<PrimExpr> x, y;
PVar<IntImm> c;
auto e = GetRef<PrimExpr>(op);
if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 &&
analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval());
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
PrimExpr VisitExpr_(const EQNode *op) final {
using namespace arith;
PVar<PrimExpr> x, y;
auto e = GetRef<PrimExpr>(op);
if ((floormod(x, y) == 0).Match(e)) {
return VisitExpr((truncmod(x, y) == 0).Eval());
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
PrimExpr VisitExpr_(const NENode *op) final {
using namespace arith;
PVar<PrimExpr> x, y;
auto e = GetRef<PrimExpr>(op);
if ((floormod(x, y) != 0).Match(e)) {
return VisitExpr((truncmod(x, y) != 0).Eval());
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
private:
PrimExpr SwapBroadcastCast(const PrimExpr &e) {
// Try to change broadcast(cast(x)) to cast(broadcast(x))
// For some targets, LLVM will generate more efficient FMA
// instruction with the latter. For example, vmla vs. vmlal
// on ARM.
if (const BroadcastNode *bcast = e.as<BroadcastNode>()) {
if (const CastNode *cast = bcast->value.as<CastNode>()) {
auto should_swap = [&]() {
// Maintain behaviour (int8 -> int16, fp16 -> fp32).
if (cast->dtype.bits() == cast->value.dtype().bits() * 2) {
return true;
}
// Check both operands are integer-like.
if (!cast->dtype.is_uint() && !cast->dtype.is_int()) {
return false;
}
if (!cast->value.dtype().is_uint() && !cast->value.dtype().is_int()) {
return false;
}
// If both are integer-like, swap if we have a widening cast.
return cast->dtype.bits() > cast->value.dtype().bits();
};
if (should_swap()) {
PrimExpr new_bcast = Broadcast(cast->value, bcast->lanes);
return Cast(bcast->dtype, new_bcast);
}
}
}
return e;
}
PrimExpr MakeFMA(const PrimExpr &a, const PrimExpr &b, const PrimExpr &c,
const AddNode *op) {
// emit fma instruction: a * b + c
PrimExpr lhs = SwapBroadcastCast(a);
PrimExpr rhs = SwapBroadcastCast(b);
if (fma_ != nullptr && op->dtype.is_float()) {
PrimExpr r = fma_(Call(op->dtype, builtin::fma(), {lhs, rhs, c}));
if (r.defined())
return this->VisitExpr(r);
} else {
if (!lhs.same_as(a) || !rhs.same_as(b)) {
PrimExpr mul = this->VisitExpr(Mul(lhs, rhs));
return Add(mul, this->VisitExpr(c));
}
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
// attribute maps, shared only when FLegalize == FLowerIntrinsic
std::vector<OpAttrMap<FLowerGeneral>> attr_maps_;
FLowerGeneral fma_{nullptr};
bool support_bitwise_op_{true};
};
Stmt LowerIntrinStmt(Stmt stmt, const std::string &target) {
arith::Analyzer analyzer;
return IntrinInjecter(&analyzer, target)(std::move(stmt));
}
namespace transform {
tir::transform::Pass LowerIntrin() {
using namespace tir::transform;
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto *n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "LowerIntrin: Require the target attribute";
arith::Analyzer analyzer;
auto mtriple = target.value()->GetAttr<String>("mtriple", "");
n->body = IntrinInjecter(&analyzer, target.value()->kind->name,
mtriple.value())(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.LowerIntrin", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerIntrin", LowerIntrin);
});
} // namespace transform
} // namespace tl
} // namespace tvm
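When the sign of the dividend cannot be proven, the pass falls back to truncdiv plus a sign correction taken from the remainder's sign bit (`rdiv + (rmod >> (bits - 1))`). A standalone restatement of that trick in plain Python (illustrative; `truncdiv`/`truncmod` are local helpers, not TVM APIs):

def truncdiv(a: int, b: int) -> int:
    # C-style division: round the quotient toward zero.
    q = abs(a) // abs(b)
    return q if (a < 0) == (b < 0) else -q

def truncmod(a: int, b: int) -> int:
    return a - truncdiv(a, b) * b

def lowered_floordiv(a: int, b: int) -> int:
    # (rmod >> 31) on int32 is 0 for rmod >= 0 and -1 for rmod < 0.
    correction = -1 if truncmod(a, b) < 0 else 0
    return truncdiv(a, b) + correction

for a in range(-100, 100):
    for b in (1, 3, 32):  # positive divisors, as this fallback requires
        assert lowered_floordiv(a, b) == a // b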
import tilelang.language as T
import tilelang.testing
import torch
@tilelang.jit(out_idx=[-1])
def _ceildiv_kernel(a: int, b: int):
@T.prim_func
def ceildiv_kernel(A: T.Tensor((1,), "int32")):
with T.Kernel(1, threads=1) as _:
A[0] = T.ceildiv(T.int32(a), T.int32(b))
return ceildiv_kernel
def run_ceildiv(a=128, b=32):
kernel = _ceildiv_kernel(a, b)
A = kernel()
print(kernel.get_kernel_source())
print(A)
def test_ceildiv():
run_ceildiv(a=128, b=32)
run_ceildiv(a=1, b=32)
run_ceildiv(a=-1, b=32)
run_ceildiv(a=-2, b=32)
@tilelang.jit
def _ceildiv_kernel_dyn(b: int):
@T.prim_func
def ceildiv_kernel(A: T.Tensor((1,), "int32"), a: T.int32):
with T.Kernel(1, threads=1) as _:
A[0] = T.ceildiv(T.int32(a), T.int32(b))
return ceildiv_kernel
def run_ceildiv_dyn(a=128, b=32):
kernel = _ceildiv_kernel_dyn(b)
A = torch.empty((1,), dtype=torch.int32, device="cuda")
kernel(A, a)
print(kernel.get_kernel_source())
print(A)
@tilelang.testing.requires_cuda
def test_ceildiv_dyn():
run_ceildiv_dyn(a=128, b=32)
run_ceildiv_dyn(a=1, b=32)
run_ceildiv_dyn(a=-1, b=32)
run_ceildiv_dyn(a=-2, b=32)
if __name__ == "__main__":
tilelang.testing.main()
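For reference, the values these cases should produce, computed with Python's floor-division-based ceildiv (illustrative):

for a, b in [(128, 32), (1, 32), (-1, 32), (-2, 32)]:
    print(a, b, -(-a // b))  # -> 4, 1, 0, 0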
@@ -138,7 +138,7 @@ def host_codegen(host_mod: tvm.IRModule, target_host: Target) -> tvm.IRModule:
     host_mod = tir.transform.BF16StorageLegalize()(host_mod)
     host_mod = tir.transform.LowerTVMBuiltin()(host_mod)
     host_mod = tir.transform.LowerCustomDatatypes()(host_mod)
-    host_mod = tir.transform.LowerIntrin()(host_mod)
+    host_mod = tilelang.transform.LowerIntrin()(host_mod)
     host_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(host_mod)
     host_mod = tir.transform.CombineContextCall()(host_mod)
     if target_host.kind.name == "llvm":
@@ -152,7 +152,7 @@ def host_codegen(host_mod: tvm.IRModule, target_host: Target) -> tvm.IRModule:
 def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
     device_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(device_mod)
-    device_mod = tir.transform.LowerIntrin()(device_mod)
+    device_mod = tilelang.transform.LowerIntrin()(device_mod)
     device_mod = tir.transform.Simplify()(device_mod)
     if target.kind.name == "cuda":
@@ -167,7 +167,7 @@ def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
 def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
     device_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(device_mod)
-    device_mod = tir.transform.LowerIntrin()(device_mod)
+    device_mod = tilelang.transform.LowerIntrin()(device_mod)
     device_mod = tir.transform.Simplify()(device_mod)
     if target.kind.name == "cuda":
         device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(
......
@@ -30,7 +30,7 @@ class BaseKernelAdapter(ABC):
             result_idx = [result_idx]
         elif isinstance(result_idx, list):
             for i, idx in enumerate(result_idx):
-                if idx >= len(params) or idx <= -len(params):
+                if idx >= len(params) or idx < -len(params):
                     raise ValueError(
                         f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}"
                     )
......
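The corrected bound admits `idx == -len(params)`, which is a valid Python negative index (it aliases index 0). Illustrative check:

params = ["A", "B", "C"]
assert params[-len(params)] is params[0]  # -3 is in range; -4 raises IndexError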
@@ -145,6 +145,12 @@ cdef class CythonKernelWrapper:
         if not tensor.is_contiguous():
             raise ValueError(f"Expected parameter {param} to be a contiguous tensor")

+    cdef object _infer_output_device(self, list inputs):
+        for tensor in inputs:
+            if isinstance(tensor, torch.Tensor):
+                return tensor.device
+        return torch.cuda.current_device()
+
     cpdef forward(self, list inputs, int64_t stream = -1, bint skip_tensor_validation = False):
         # Validate input dimensions and prepare for kernel execution
         cdef int total_params = len(self.params)
@@ -170,6 +176,7 @@ cdef class CythonKernelWrapper:
         cdef int ins_idx = 0
         cdef list tensor_list = []
+        device = None

         # Prepare input and output tensors
         for i in range(len(self.params)):
@@ -185,7 +192,10 @@ cdef class CythonKernelWrapper:
                         shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
                     else:  # Already converted to Python int during initialization
                         shape.append(s)
-                device = inputs[0].device if len(inputs) > 0 else torch.cuda.current_device()
+
+                if device is None:
+                    device = self._infer_output_device(inputs)
+
                 if len(shape) == 0:
                     param_name = self.params[i].name if hasattr(self.params[i], 'name') else f'parameter_{i}'
                     raise ValueError(
@@ -263,4 +273,4 @@ cdef class CythonKernelWrapper:
             return tensor_list[self.result_idx[0]]
         else:
-            return [tensor_list[i] for i in self.result_idx]
\ No newline at end of file
+            return [tensor_list[i] for i in self.result_idx]
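Previously the output device was taken unconditionally from `inputs[0]`; with this change the wrapper scans for the first `torch.Tensor` input and only falls back to the current CUDA device when none exists. The rule restated standalone (illustrative):

import torch

def infer_output_device(inputs):  # mirrors the new _infer_output_device
    for t in inputs:
        if isinstance(t, torch.Tensor):
            return t.device
    return torch.cuda.current_device()

assert infer_output_device([3, torch.empty(1)]) == torch.device("cpu")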
@@ -438,6 +438,12 @@ def LowerThreadAllreduce():
     return _ffi_api.LowerThreadAllreduce()  # type: ignore

+def LowerIntrin():
+    """Lower intrinsic calls and ops to target-specific IR
+    (tilelang's migrated replacement for tir.transform.LowerIntrin).
+    """
+    return _ffi_api.LowerIntrin()  # type: ignore
+
 def LowerDeviceKernelLaunch():
     """
     Create and return a transform pass that lowers device kernel launch constructs to target-specific IR.
......
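With the Python wrapper in place, call sites switch from the upstream pass to the tilelang one, as the codegen diffs above show:

# before
device_mod = tir.transform.LowerIntrin()(device_mod)
# after
device_mod = tilelang.transform.LowerIntrin()(device_mod)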