Unverified Commit f8d3e73e authored by Yu Cheng's avatar Yu Cheng Committed by GitHub
Browse files

[Bugfix] Fix missing reg alloc in custom warp specialization (#1084)

parent bc37ea69
...@@ -95,8 +95,7 @@ public: ...@@ -95,8 +95,7 @@ public:
private: private:
Stmt VisitStmt_(const EvaluateNode *op) final { Stmt VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) { if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(set_max_nreg()) || if (call->op.same_as(no_set_max_nreg())) {
call->op.same_as(no_set_max_nreg())) {
// Remove the original set_max_nreg calls as they will be re-inserted // Remove the original set_max_nreg calls as they will be re-inserted
// at appropriate locations // at appropriate locations
return Evaluate(0); return Evaluate(0);
...@@ -136,11 +135,9 @@ private: ...@@ -136,11 +135,9 @@ private:
// Only inject if we have valid register hints and no SIMT copy // Only inject if we have valid register hints and no SIMT copy
bool has_simt_copy = SimtCopyDetector::Detect(producer_body); bool has_simt_copy = SimtCopyDetector::Detect(producer_body);
if (dec_reg >= 0 && inc_reg >= 0 && !has_simt_copy) { if (dec_reg == 0 && inc_reg == 0 && !has_simt_copy) {
auto inc_reg_num = auto inc_reg_num = IntImm(DataType::Int(32), 240);
IntImm(DataType::Int(32), inc_reg == 0 ? 240 : inc_reg); auto dec_reg_num = IntImm(DataType::Int(32), 24);
auto dec_reg_num =
IntImm(DataType::Int(32), dec_reg == 0 ? 24 : dec_reg);
inc_reg_stmt = Evaluate( inc_reg_stmt = Evaluate(
Call(DataType::Handle(), set_max_nreg(), {inc_reg_num, 1})); Call(DataType::Handle(), set_max_nreg(), {inc_reg_num, 1}));
dec_reg_stmt = Evaluate( dec_reg_stmt = Evaluate(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment