Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
01a10755
Commit
01a10755
authored
Mar 04, 2024
by
yuguo-Jack
Browse files
2.5.2-dtk24.04
parent
63eb0da5
Changes
558
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
369 additions
and
135 deletions
+369
-135
paddle/cinn/optim/replace_call_with_expr.cc
paddle/cinn/optim/replace_call_with_expr.cc
+4
-4
paddle/cinn/optim/replace_call_with_expr_test.cc
paddle/cinn/optim/replace_call_with_expr_test.cc
+1
-1
paddle/cinn/optim/replace_const_param_to_integer.cc
paddle/cinn/optim/replace_const_param_to_integer.cc
+1
-1
paddle/cinn/optim/replace_cross_thread_reduction.cc
paddle/cinn/optim/replace_cross_thread_reduction.cc
+189
-0
paddle/cinn/optim/replace_cross_thread_reduction.h
paddle/cinn/optim/replace_cross_thread_reduction.h
+33
-0
paddle/cinn/optim/replace_cross_thread_reduction_test.cc
paddle/cinn/optim/replace_cross_thread_reduction_test.cc
+85
-0
paddle/cinn/optim/replace_var_with_expr.cc
paddle/cinn/optim/replace_var_with_expr.cc
+3
-3
paddle/cinn/optim/tensor_write_tell.cc
paddle/cinn/optim/tensor_write_tell.cc
+0
-19
paddle/cinn/optim/tensor_write_tell.h
paddle/cinn/optim/tensor_write_tell.h
+0
-58
paddle/cinn/optim/transform_gpu_forloop.cc
paddle/cinn/optim/transform_gpu_forloop.cc
+8
-8
paddle/cinn/optim/transform_polyfor_to_for.cc
paddle/cinn/optim/transform_polyfor_to_for.cc
+3
-3
paddle/cinn/optim/unroll_loops.cc
paddle/cinn/optim/unroll_loops.cc
+5
-5
paddle/cinn/optim/var_mod_simplify.cc
paddle/cinn/optim/var_mod_simplify.cc
+2
-2
paddle/cinn/optim/vectorize_loops.cc
paddle/cinn/optim/vectorize_loops.cc
+27
-23
paddle/cinn/optim/vectorize_loops.h
paddle/cinn/optim/vectorize_loops.h
+1
-1
paddle/cinn/poly/ast_gen.cc
paddle/cinn/poly/ast_gen.cc
+1
-1
paddle/cinn/poly/ast_gen_test.cc
paddle/cinn/poly/ast_gen_test.cc
+1
-1
paddle/cinn/poly/dim.cc
paddle/cinn/poly/dim.cc
+1
-1
paddle/cinn/poly/domain.cc
paddle/cinn/poly/domain.cc
+3
-3
paddle/cinn/poly/domain_add_unit_loop_mutator.cc
paddle/cinn/poly/domain_add_unit_loop_mutator.cc
+1
-1
No files found.
Too many changes to show.
To preserve performance only
558 of 558+
files are displayed.
Plain diff
Email patch
paddle/cinn/optim/replace_call_with_expr.cc
View file @
01a10755
...
...
@@ -14,9 +14,9 @@
#include "paddle/cinn/optim/replace_call_with_expr.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
namespace
cinn
{
...
...
@@ -36,7 +36,7 @@ struct ReplaceCallWithExprModifier : public ir::IRMutator<> {
VLOG
(
3
)
<<
"Processing Call node "
<<
*
op
;
if
(
statement_
!=
node
->
name
)
return
;
Expr
expr_candidate
=
IRCopy
(
candidate_
);
Expr
expr_candidate
=
ir
::
ir_utils
::
IRCopy
(
candidate_
);
VLOG
(
3
)
<<
"Original candidate expr: "
<<
candidate_
;
VLOG
(
3
)
<<
"Copied candidate expr: "
<<
expr_candidate
;
...
...
@@ -62,7 +62,7 @@ void ReplaceIslCallWithExpr(Expr *e,
const
Expr
&
candidate
,
const
std
::
map
<
std
::
string
,
Expr
>
&
axis_map
)
{
VLOG
(
3
)
<<
"ReplaceCallWithExpr, original expression: "
<<
candidate
;
Expr
copied
=
IRCopy
(
candidate
);
Expr
copied
=
ir
::
ir_utils
::
IRCopy
(
candidate
);
// update the axis in the copied expression.
// we treat the Store node as the normal statement, the others like Call node
...
...
paddle/cinn/optim/replace_call_with_expr_test.cc
View file @
01a10755
...
...
@@ -17,8 +17,8 @@
#include <gtest/gtest.h>
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/placeholder.h"
#include "paddle/cinn/poly/ast_gen.h"
...
...
paddle/cinn/optim/replace_const_param_to_integer.cc
View file @
01a10755
...
...
@@ -14,7 +14,7 @@
#include "paddle/cinn/optim/replace_const_param_to_integer.h"
#include "paddle/cinn/ir/
utils/
ir_mutator.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/poly/ast_gen.h"
#include "paddle/cinn/utils/string.h"
...
...
paddle/cinn/optim/replace_cross_thread_reduction.cc
0 → 100644
View file @
01a10755
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
 * This file implements the pass that replaces cross-thread reductions with
 * calls to external reduce functions.
 */
#pragma once
#include "paddle/cinn/optim/replace_cross_thread_reduction.h"
#include <vector>
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/hlir/pe/reduction.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/lang/compute.h"
namespace
cinn
{
namespace
optim
{
namespace
{
/**
 * Ordering predicate for a std::set of buffers, keyed on the buffer name.
 *
 * Two buffers with the same name are treated as equivalent, so a set using
 * this comparator holds at most one buffer per name.
 *
 * The comparison has to be a strict weak ordering: for two different names
 * exactly one of cmp(a, b) and cmp(b, a) may be true. Answering "true" for
 * every pair of different names breaks that requirement and makes the
 * behaviour of std::set undefined, so the names are compared with operator<.
 */
struct BufferCmp {
  bool operator()(const ir::Buffer& a, const ir::Buffer& b) const {
    return a->name < b->name;
  }
};
// Per-thread (thread_local) set of buffers, at most one per name according to
// BufferCmp. CrossThreadReductionReplacer inserts a buffer here each time it
// rewrites a reduction store, and drains the set (appends to temp_bufs, then
// clears) after visiting a lowered function.
thread_local
std
::
set
<
ir
::
Buffer
,
BufferCmp
>
shm_buffer_
;
/**
 * IR mutator that rewrites the update statement of a reduction into a call to
 * an external cross-thread reduce function.
 *
 * While walking the IR it keeps the stack of enclosing For loops. For every
 * ScheduleBlockRealize accepted by CanReplace(), the value of the update
 * Store (an Add / Mul / Max / Min / And / Or node) is replaced by
 * lang::CallExtern(<reduce function>, {<second operand>, <buffer>}), where
 * the buffer is a 32-element buffer with memory type GPUShared. The created
 * buffers are collected in shm_buffer_ and appended to the temp_bufs of the
 * lowered function once it has been visited.
 */
struct CrossThreadReductionReplacer : public ir::IRMutator<> {
  //! Run the replacement on \p expr in place.
  void operator()(ir::Expr* expr) { Visit(expr); }

 private:
  /**
   * Tell whether the update of \p block_realize can be replaced.
   *
   * Returns true only when all of the following hold:
   *  - the block name does not start with "root";
   *  - at least one enclosing loop whose loop var appears in the iter value
   *    of a reduce axis is bound to a GPU thread, and every such loop has an
   *    extent of at most 1024;
   *  - the innermost enclosing loop is one of those loops;
   *  - those loops are adjacent to each other in the loop nest.
   */
  bool CanReplace(const ir::ScheduleBlockRealize* block_realize) {
    const ir::ScheduleBlock* schedule_block =
        block_realize->schedule_block.As<ir::ScheduleBlock>();
    CHECK_NOTNULL(schedule_block);

    if (schedule_block->name.substr(0, 4) == "root") {
      return false;
    }

    const std::vector<ir::Expr>& iter_values = block_realize->iter_values;
    const std::vector<ir::Var>& iter_vars = schedule_block->iter_vars;

    // Names of all variables used in the iter values bound to reduce axes.
    std::unordered_set<std::string> reduce_var_names;
    for (int i = 0; i < iter_values.size(); ++i) {
      if (!iter_vars[i]->is_reduce_axis) {
        continue;
      }
      // The collector is only used to walk the expression: the callback
      // records the variable names and never selects a node.
      ir::ir_utils::CollectIRNodesWithoutTensor(
          iter_values[i], [&](const ir::Expr* x) {
            if (x->as_var()) {
              reduce_var_names.insert(x->as_var()->name);
            }
            return false;
          });
    }

    // Positions (in the loop stack) of thread-bound reduce loops.
    std::vector<int> thread_binded_reduce_loop_indices;
    for (int i = 0; i < cur_loops_.size(); ++i) {
      auto* loop = cur_loops_[i].As<ir::For>();
      if (reduce_var_names.count(loop->loop_var->name) == 0) {
        continue;
      }
      if (!loop->is_gpu_thread_binded()) {
        continue;
      }
      if (ir::GetLoopExtent(cur_loops_[i]) > 1024) {
        return false;
      }
      thread_binded_reduce_loop_indices.push_back(i);
    }

    // The innermost loop has to be a thread-bound reduce loop.
    if (thread_binded_reduce_loop_indices.empty() ||
        thread_binded_reduce_loop_indices.back() != cur_loops_.size() - 1) {
      return false;
    }

    // The thread-bound reduce loops have to be adjacent in the loop nest.
    for (int i = 1; i < thread_binded_reduce_loop_indices.size(); ++i) {
      if (thread_binded_reduce_loop_indices[i - 1] + 1 !=
          thread_binded_reduce_loop_indices[i]) {
        return false;
      }
    }

    return true;
  }

  void Visit(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

  /**
   * Visit the function body, then append the buffers collected in
   * shm_buffer_ to the function's temp_bufs, unless temp_bufs already holds a
   * buffer with the name of one of them. shm_buffer_ is cleared afterwards.
   */
  void Visit(const ir::_LoweredFunc_* expr, ir::Expr* op) override {
    ir::IRMutator<>::Visit(expr, op);

    auto& temp_bufs = op->as_lowered_func()->temp_bufs;
    auto found = std::find_if(
        temp_bufs.begin(), temp_bufs.end(), [&](const ir::Buffer& buf) -> bool {
          for (auto& tmp_buf : shm_buffer_) {
            if (buf->name == tmp_buf->name) return true;
          }
          return false;
        });
    if (found == temp_bufs.end()) {
      temp_bufs.insert(temp_bufs.end(), shm_buffer_.begin(), shm_buffer_.end());
    }
    shm_buffer_.clear();
  }

  /**
   * Replace the value of the block's update Store by an external reduce call
   * when CanReplace() accepts the block, then continue the traversal.
   *
   * The block body must be a Store or a Block holding exactly one statement,
   * that statement must be a Store, and the second operand of the reduce
   * operation must be a Load; violations are reported through CHECK instead
   * of dereferencing a null pointer.
   */
  void Visit(const ir::ScheduleBlockRealize* expr, ir::Expr* op) override {
    if (!CanReplace(expr)) {
      VLOG(6) << "Can't replace cross thread reduction: " << *op;
      IRMutator::Visit(expr, op);
      return;
    }
    VLOG(6) << "Can replace cross thread reduction: " << *op;

    const ir::ScheduleBlock* schedule_block =
        expr->schedule_block.As<ir::ScheduleBlock>();
    CHECK_NOTNULL(schedule_block);
    ir::Expr original_update_body = schedule_block->body;
    ir::Expr original_update_stmt;
    CHECK(original_update_body.As<ir::Block>() ||
          original_update_body.As<ir::Store>());
    if (auto* update_block = original_update_body.As<ir::Block>()) {
      CHECK_EQ(update_block->stmts.size(), 1);
      original_update_stmt = update_block->stmts[0];
    } else if (original_update_body.As<ir::Store>()) {
      original_update_stmt = original_update_body;
    }

    auto* update_store = original_update_stmt.As<ir::Store>();
    CHECK(update_store)
        << "The update statement of a cross thread reduction must be a Store";

// For the given binary Op: if the stored value is an Op node, build the
// buffer for the reduce call, remember it in shm_buffer_ and replace the
// stored value by the external call on the node's second operand.
#define REPLACE_TO_EXTERNAL_CALL(Op)                                      \
  if (auto* node = update_store->value.As<Op>()) {                        \
    auto& operand = node->b();                                            \
    auto* operand_load = operand.As<ir::Load>();                          \
    CHECK(operand_load)                                                   \
        << "The operand of a cross thread reduction must be a Load: "     \
        << operand;                                                       \
    std::string reduce_func_name =                                        \
        hlir::pe::CrossThreadReduceExternalFuncName(                      \
            update_store->value, operand_load->tensor);                   \
    auto tmp_dtype = operand_load->tensor.as_tensor()->type();            \
    auto tmp_buffer = ir::_Buffer_::Make(                                 \
        "shm32_" + hlir::pe::Type2StrForReduce(tmp_dtype) + "_reduce",    \
        {ir::Expr(32)});                                                  \
    tmp_buffer->dtype = tmp_dtype;                                        \
    tmp_buffer->memory_type = ir::MemoryType::GPUShared;                  \
    shm_buffer_.insert(tmp_buffer);                                       \
    update_store->value =                                                 \
        lang::CallExtern(reduce_func_name, {node->b(), tmp_buffer});      \
  }

    REPLACE_TO_EXTERNAL_CALL(ir::Add)
    REPLACE_TO_EXTERNAL_CALL(ir::Mul)
    REPLACE_TO_EXTERNAL_CALL(ir::Max)
    REPLACE_TO_EXTERNAL_CALL(ir::Min)
    REPLACE_TO_EXTERNAL_CALL(ir::And)
    REPLACE_TO_EXTERNAL_CALL(ir::Or)
#undef REPLACE_TO_EXTERNAL_CALL

    VLOG(6) << "Replace cross thread reduction: " << *op;

    IRMutator::Visit(expr, op);
  }

  //! Keep the stack of enclosing loops up to date while visiting the body.
  void Visit(const ir::For* expr, ir::Expr* op) override {
    cur_loops_.push_back(*op);
    IRMutator::Visit(expr, op);
    cur_loops_.pop_back();
  }

  //! For loops enclosing the node being visited, outermost first.
  std::vector<ir::Expr> cur_loops_;
};
}
// namespace
// Entry point of the pass: runs a fresh CrossThreadReductionReplacer on `e`.
void ReplaceCrossThreadReduction(Expr* e) {
  CrossThreadReductionReplacer replacer;
  replacer(e);
}
}
// namespace optim
}
// namespace cinn
paddle/cinn/optim/replace_cross_thread_reduction.h
0 → 100644
View file @
01a10755
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
 * This file declares the pass that replaces cross-thread reductions with
 * calls to external reduce functions.
 */
#pragma once
#include <vector>
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/ir/ir.h"
namespace
cinn
{
namespace
optim
{
/**
 * Replace cross-thread reductions with calls to external reduce functions.
 *
 * @param e The expression to process; it is modified in place.
 */
void
ReplaceCrossThreadReduction
(
Expr
*
e
);
}
// namespace optim
}
// namespace cinn
paddle/cinn/optim/replace_cross_thread_reduction_test.cc
0 → 100644
View file @
01a10755
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/optim/replace_cross_thread_reduction.h"
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
namespace
optim
{
// Lowers a reduce_sum over a 64x128 input, binds the outer loop to blockIdx.x
// and the reduce loop to threadIdx.x, runs ReplaceCrossThreadReduction and
// compares the printed IR with the expected text. Only built with CUDA.
TEST(CrossThreadReductionReplacer, basic) {
#ifdef CINN_WITH_CUDA
  Context::Global().ResetNameId();

  // Build B[i] = sum_j A[i, j].
  Placeholder<float> A("A", {Expr(64), Expr(128)});
  Target target = common::DefaultNVGPUTarget();
  Module::Builder builder("reduce_sum", target);
  Var reduce_j(128, "reduce_j");
  ir::Tensor B = Compute(
      {Expr(64)},
      [&](Var i) { return lang::ReduceSum(A(i, reduce_j), {reduce_j}); },
      "B");

  ast_gen_ius::TensorGroup tensor_group({A, B});
  auto lowered_func = lang::LowerToAst("reduce_sum", {A, B}, &tensor_group);
  VLOG(6) << "original func\n" << lowered_func;

  // Bind the loops; the loops are queried again after the first Bind.
  ir::ModuleExpr module_expr({lowered_func->body});
  ir::IRSchedule schedule(module_expr);
  schedule.Bind(schedule.GetLoops("B")[0], "blockIdx.x");
  schedule.Bind(schedule.GetLoops("B")[1], "threadIdx.x");

  ir::Expr new_func = schedule.GetModule().GetExprs()[0];
  VLOG(6) << "After Bind: " << new_func;

  ReplaceCrossThreadReduction(&new_func);
  VLOG(6) << "After ReplaceCrossThreadReduction: " << new_func;

  const std::string expected_ir = utils::Trim(R"ROC({
ScheduleBlock(root)
{
thread_bind[blockIdx.x] for (i, 0, 64)
{
ScheduleBlock(B__reduce_init)
{
i0 = axis.bind(i)
B__reduce_init[i0] = 0.00000000f
}
thread_bind[threadIdx.x] for (reduce_j, 0, 128)
{
ScheduleBlock(B)
{
i0_0, i1 = axis.bind(i, reduce_j)
B[i0_0] = cinn_block_reduce_sum_fp32_internal_shm(A[i0_0, i1], _Buffer_<cinn_buffer_t*: 32>(shm32__fp32_reduce))
}
}
}
}
}
)ROC");
  EXPECT_EQ(utils::GetStreamCnt(new_func), expected_ir);
#endif
}
}
// namespace optim
}
// namespace cinn
paddle/cinn/optim/replace_var_with_expr.cc
View file @
01a10755
...
...
@@ -16,11 +16,11 @@
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/replace_const_param_to_integer.h"
...
...
@@ -41,7 +41,7 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<> {
private:
void
Visit
(
const
ir
::
_Var_
*
expr
,
Expr
*
op
)
override
{
if
(
expr
->
name
==
var_
->
name
&&
(
do_replace_
||
visit_all_
))
{
auto
copied
=
IRCopy
(
expr_
);
auto
copied
=
ir
::
ir_utils
::
IRCopy
(
expr_
);
*
op
=
copied
;
}
}
...
...
paddle/cinn/optim/tensor_write_tell.cc
deleted
100644 → 0
View file @
63eb0da5
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/optim/tensor_write_tell.h"
namespace
cinn
{
namespace
optim
{}
// namespace optim
}
// namespace cinn
paddle/cinn/optim/tensor_write_tell.h
deleted
100644 → 0
View file @
63eb0da5
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <set>
#include <string>
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
namespace
cinn
{
namespace
optim
{
/**
 * Records the names of the tensors written in an expression: the target
 * tensor of every Store, and every tensor that is a call node.
 */
struct TensorWriteTeller : public ir::IRMutator<const Expr*> {
  //! Collect the write info in \p op.
  void Collect(const Expr* op) { Visit(op, op); }

  //! Tell whether a tensor named \p tensor_name was recorded as written.
  bool IsWrite(const std::string& tensor_name) const {
    return tensor_written.find(tensor_name) != tensor_written.end();
  }

 private:
  //! Names of the tensors recorded so far.
  std::set<std::string> tensor_written;

  void Visit(const Expr* expr, const Expr* op) override {
    IRMutator::Visit(expr, op);
  }

  //! Record the stored tensor, then keep visiting the Store's children.
  void Visit(const ir::Store* expr, const Expr* op) override {
    auto* store = op->As<ir::Store>();
    CHECK(store);
    auto* stored_tensor = store->tensor.As<ir::_Tensor_>();
    CHECK(stored_tensor);
    tensor_written.insert(stored_tensor->name);
    IRMutator::Visit(expr, op);
  }

  //! Record the tensor when it is a call node; children are not visited.
  void Visit(const ir::_Tensor_* op, const Expr* expr) override {
    auto* tensor = expr->As<ir::_Tensor_>();
    if (!tensor->is_call_node()) {
      return;
    }
    tensor_written.insert(tensor->name);
  }
};
}
// namespace optim
}
// namespace cinn
paddle/cinn/optim/transform_gpu_forloop.cc
View file @
01a10755
...
...
@@ -24,9 +24,9 @@
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/poly/isl_utils.h"
...
...
@@ -185,7 +185,7 @@ class RestructureVarNodes : public ir::IRMutator<> {
void
Visit
(
const
ir
::
Load
*
load
,
Expr
*
op
)
override
{
std
::
vector
<
ir
::
Expr
>
indices_copied
;
for
(
const
ir
::
Expr
&
indice
:
load
->
indices
)
{
indices_copied
.
push_back
(
IRCopy
(
indice
));
indices_copied
.
push_back
(
ir
::
ir_utils
::
IRCopy
(
indice
));
}
op
->
As
<
ir
::
Load
>
()
->
indices
=
indices_copied
;
...
...
@@ -195,7 +195,7 @@ class RestructureVarNodes : public ir::IRMutator<> {
void
Visit
(
const
ir
::
Store
*
store
,
Expr
*
op
)
override
{
std
::
vector
<
ir
::
Expr
>
indices_copied
;
for
(
const
ir
::
Expr
&
indice
:
store
->
indices
)
{
indices_copied
.
push_back
(
IRCopy
(
indice
));
indices_copied
.
push_back
(
ir
::
ir_utils
::
IRCopy
(
indice
));
}
op
->
As
<
ir
::
Store
>
()
->
indices
=
indices_copied
;
...
...
@@ -396,7 +396,7 @@ class ReplaceLoopVarToGpu : public ir::IRMutator<> {
auto
bind_info
=
for_ir
->
bind_info
();
std
::
string
var_name
=
""
;
if
(
bind_info
.
offset
=
=
0
)
if
(
bind_info
.
offset
<
=
0
)
var_name
=
"x"
;
else
if
(
bind_info
.
offset
==
1
)
var_name
=
"y"
;
...
...
@@ -585,8 +585,8 @@ class ResizeBufferSizeVisitor : public ir::IRMutator<> {
}
int
BufferSize
(
ir
::
Expr
indice
)
{
auto
copy
=
IRCopy
(
indice
);
auto
vars
=
ir
::
CollectIRNodesInOrder
(
auto
copy
=
ir
::
ir_utils
::
IRCopy
(
indice
);
auto
vars
=
ir
::
ir_utils
::
CollectIRNodesInOrder
(
copy
,
[](
const
ir
::
Expr
*
expr
)
{
return
expr
->
As
<
ir
::
_Var_
>
();
});
int
max_range
=
1
;
...
...
@@ -598,7 +598,7 @@ class ResizeBufferSizeVisitor : public ir::IRMutator<> {
auto
extent
=
loop_2_extent_
.
find
(
var
->
name
)
->
second
;
for
(
int
idx
=
0
;
idx
<
extent
;
++
idx
)
{
auto
tmp
=
IRCopy
(
index
);
auto
tmp
=
ir
::
ir_utils
::
IRCopy
(
index
);
ReplaceVarWithExpr
(
&
tmp
,
var
,
Expr
(
idx
));
if
(
deep
==
vars
.
size
()
-
1
)
{
...
...
paddle/cinn/optim/transform_polyfor_to_for.cc
View file @
01a10755
...
...
@@ -21,11 +21,11 @@
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/common/type.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/ir_simplify.h"
namespace
cinn
{
...
...
paddle/cinn/optim/unroll_loops.cc
View file @
01a10755
...
...
@@ -17,11 +17,11 @@
#include <utility>
#include <vector>
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/ir_replace.h"
#include "paddle/cinn/ir/utils/ir_replace.h"
namespace
cinn
{
namespace
optim
{
...
...
@@ -94,8 +94,8 @@ struct UnrollMutator : public ir::IRMutator<Expr*> {
for
(
int
i
=
min
->
value
;
i
<
extent
->
value
;
i
++
)
{
Expr
start
=
op
->
min
+
i
;
body
.
push_back
(
optim
::
IRCopy
(
op
->
body
));
optim
::
IrReplace
(
&
body
.
back
(),
op
->
loop_var
,
start
);
body
.
push_back
(
ir
::
ir_utils
::
IRCopy
(
op
->
body
));
cinn
::
ir
::
ir_utils
::
IrReplace
(
&
body
.
back
(),
op
->
loop_var
,
start
);
}
*
expr
=
ir
::
Block
::
Make
(
body
);
...
...
paddle/cinn/optim/var_mod_simplify.cc
View file @
01a10755
...
...
@@ -17,8 +17,8 @@
#include <absl/container/flat_hash_map.h>
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/ir/
utils/
ir_mutator.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace
cinn
::
optim
{
...
...
paddle/cinn/optim/vectorize_loops.cc
View file @
01a10755
...
...
@@ -25,13 +25,12 @@
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/ir_replace.h"
#include "paddle/cinn/ir/utils/ir_replace.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/tensor_write_tell.h"
#include "paddle/cinn/optim/unroll_loops.h"
#include "paddle/cinn/utils/functional.h"
...
...
@@ -130,7 +129,8 @@ class TensorVectorizeTeller : public ir::IRMutator<const Expr *> {
// the iter val must appear in the last index
if
(
indices
.
empty
()
||
ir
::
CollectIRNodes
(
indices
.
back
(),
find_matched_var_fn
).
empty
())
{
ir
::
ir_utils
::
CollectIRNodes
(
indices
.
back
(),
find_matched_var_fn
)
.
empty
())
{
VLOG
(
5
)
<<
"Loop var:"
<<
iter_var_
->
name
<<
" is not used in the last index"
;
return
false
;
...
...
@@ -138,7 +138,8 @@ class TensorVectorizeTeller : public ir::IRMutator<const Expr *> {
// the iter val can't appear in mulitple indices
for
(
int
i
=
0
;
i
<
indices
.
size
()
-
1
;
++
i
)
{
auto
repeat_found
=
ir
::
CollectIRNodes
(
indices
[
i
],
find_matched_var_fn
);
auto
repeat_found
=
ir
::
ir_utils
::
CollectIRNodes
(
indices
[
i
],
find_matched_var_fn
);
if
(
!
repeat_found
.
empty
())
{
VLOG
(
5
)
<<
"Loop var:"
<<
iter_var_
->
name
<<
" is used at more than last index, current:"
<<
i
;
...
...
@@ -147,12 +148,12 @@ class TensorVectorizeTeller : public ir::IRMutator<const Expr *> {
}
// check tensor accessed sequentially by comparing index one by one
Expr
first_idx
=
optim
::
IRCopy
(
indices
.
back
());
optim
::
IrReplace
(
&
first_idx
,
Expr
(
iter_var_
),
Expr
(
0
));
Expr
first_idx
=
ir
::
ir_utils
::
IRCopy
(
indices
.
back
());
cinn
::
ir
::
ir_utils
::
IrReplace
(
&
first_idx
,
Expr
(
iter_var_
),
Expr
(
0
));
const
auto
&
interval
=
var_intervals_
->
at
(
iter_var_
->
name
);
for
(
int
i
=
1
;
i
<
interval
.
r
;
++
i
)
{
Expr
next_idx
=
optim
::
IRCopy
(
indices
.
back
());
optim
::
IrReplace
(
&
next_idx
,
Expr
(
iter_var_
),
Expr
(
i
));
Expr
next_idx
=
ir
::
ir_utils
::
IRCopy
(
indices
.
back
());
cinn
::
ir
::
ir_utils
::
IrReplace
(
&
next_idx
,
Expr
(
iter_var_
),
Expr
(
i
));
auto
gap
=
common
::
AutoSimplify
(
Expr
(
next_idx
-
first_idx
));
if
(
!
gap
.
As
<
IntImm
>
()
||
gap
.
as_int32
()
!=
i
)
{
VLOG
(
5
)
<<
"Tensor:"
<<
tensor
->
name
...
...
@@ -185,7 +186,7 @@ class CudaVectorizer : public IRMutator<Expr *> {
const
Var
iter_var_
;
// the loop var of the vecotrized loop
const
int
factor_
;
// the factor for vectorize
TensorWriteTeller
write_teller_
;
std
::
set
<
std
::
string
>
write_teller_
;
TensorVectorizeTeller
vectorized_teller_
;
absl
::
flat_hash_map
<
std
::
string
,
Var
>
tensor2vectorized_vars_
;
...
...
@@ -215,7 +216,7 @@ class CudaVectorizer : public IRMutator<Expr *> {
}
void
Visit
(
Expr
*
expr
)
{
write_teller_
.
Collect
(
expr
);
write_teller_
=
ir
::
ir_utils
::
CollectTensorNeedsWrite
(
expr
);
vectorized_teller_
.
Collect
(
expr
);
IRMutator
<
Expr
*>::
Visit
(
expr
,
expr
);
}
...
...
@@ -289,7 +290,7 @@ class CudaVectorizer : public IRMutator<Expr *> {
const
std
::
vector
<
Expr
>
&
indices
,
bool
is_store
)
{
auto
*
node
=
tensor
.
As
<
ir
::
_Tensor_
>
();
bool
is_const
=
!
write_teller_
.
IsWrite
(
node
->
name
);
bool
is_const
=
!
write_teller_
.
count
(
node
->
name
);
// generate the corresponding vector type
Type
scalar_type
=
tensor
->
type
().
ElementOf
();
...
...
@@ -309,7 +310,8 @@ class CudaVectorizer : public IRMutator<Expr *> {
// generate a get_addr expr to get the address of the tensor
Expr
converted_tensor
=
Load
::
Make
(
tensor
,
indices
);
optim
::
IrReplace
(
&
converted_tensor
,
iter_var_
,
Expr
(
int32_t
(
0
)));
cinn
::
ir
::
ir_utils
::
IrReplace
(
&
converted_tensor
,
iter_var_
,
Expr
(
int32_t
(
0
)));
auto
get_addr
=
ir
::
intrinsics
::
GetAddr
::
Make
(
converted_tensor
);
// generate a let expression to cast the tensor into the local vector
...
...
@@ -798,7 +800,7 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
cuda_vectorizer
.
Visit
(
&
new_forloop
->
body
);
// unroll the new forloop to compute each element of the vector
// iteratively
auto
copied_loop
=
optim
::
IRCopy
(
_new_forloop
);
auto
copied_loop
=
ir
::
ir_utils
::
IRCopy
(
_new_forloop
);
copied_loop
.
As
<
ir
::
For
>
()
->
set_unrolled
();
optim
::
UnrollLoop
(
&
copied_loop
);
// add cast exprs of vector type in the front of vectorized forloop,
...
...
@@ -881,13 +883,14 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
Var
new_iterator_outer
(
common
::
UniqName
(
outer_for
->
loop_var
->
name
+
"_s"
));
Expr
inner_for_b
=
Block
::
Make
({
For
::
Make
(
new_iterator_inner
,
inner_for
->
min
,
b
,
ForType
::
Serial
,
DeviceAPI
::
UNK
,
IRCopy
(
inner_for
->
body
))});
optim
::
IrReplace
(
Expr
inner_for_b
=
Block
::
Make
({
For
::
Make
(
new_iterator_inner
,
inner_for
->
min
,
b
,
ForType
::
Serial
,
DeviceAPI
::
UNK
,
ir
::
ir_utils
::
IRCopy
(
inner_for
->
body
))});
cinn
::
ir
::
ir_utils
::
IrReplace
(
&
inner_for_b
,
inner_for
->
loop_var
,
Expr
(
new_iterator_inner
));
Expr
out_for_b
=
For
::
Make
(
new_iterator_outer
,
...
...
@@ -897,7 +900,7 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
outer_for
->
device_api
,
inner_for_b
,
outer_for
->
vectorize_info
());
optim
::
IrReplace
(
cinn
::
ir
::
ir_utils
::
IrReplace
(
&
out_for_b
,
outer_for
->
loop_var
,
Expr
(
new_iterator_outer
));
*
expr
=
Block
::
Make
({
out_for_a
,
out_for_b
});
VLOG
(
2
)
<<
*
expr
;
...
...
@@ -959,7 +962,8 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
}
else
{
new_index
=
Expr
(
forloop
->
loop_var
)
*
factor
+
Expr
(
new_iterator
);
}
optim
::
IrReplace
(
&
forloop
->
body
,
forloop
->
loop_var
,
new_index
);
cinn
::
ir
::
ir_utils
::
IrReplace
(
&
forloop
->
body
,
forloop
->
loop_var
,
new_index
);
auto
new_forloop
=
For
::
Make
(
new_iterator
,
forloop
->
min
,
make_const
(
factor
),
...
...
paddle/cinn/optim/vectorize_loops.h
View file @
01a10755
...
...
@@ -14,7 +14,7 @@
#pragma once
#include "paddle/cinn/ir/
utils/
ir_mutator.h"
#include "paddle/cinn/ir/ir_mutator.h"
namespace
cinn
{
namespace
optim
{
...
...
paddle/cinn/poly/ast_gen.cc
View file @
01a10755
...
...
@@ -20,7 +20,7 @@
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/poly/domain_add_unit_loop_mutator.h"
#include "paddle/cinn/poly/isl_utils.h"
...
...
paddle/cinn/poly/ast_gen_test.cc
View file @
01a10755
...
...
@@ -22,7 +22,7 @@
#include <vector>
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/placeholder.h"
...
...
paddle/cinn/poly/dim.cc
View file @
01a10755
...
...
@@ -14,7 +14,7 @@
#include "paddle/cinn/poly/dim.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/utils/string.h"
...
...
paddle/cinn/poly/domain.cc
View file @
01a10755
...
...
@@ -23,7 +23,7 @@
#include <unordered_set>
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/ir/
utils/
ir_visitor.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
...
...
@@ -70,8 +70,8 @@ void Domain::ExtractParams() {
std
::
unordered_set
<
std
::
string
>
var_names
;
auto
collect_param_fn
=
[
&
](
Expr
&
e
)
{
if
(
!
e
.
is_constant
())
{
auto
vars
=
ir
::
CollectIRNodes
(
e
,
[](
const
Expr
*
e
)
{
return
e
->
is_var
();
});
auto
vars
=
ir
::
ir_utils
::
CollectIRNodes
(
e
,
[](
const
Expr
*
e
)
{
return
e
->
is_var
();
});
for
(
auto
&
var
:
vars
)
var_names
.
insert
(
var
.
As
<
ir
::
_Var_
>
()
->
name
);
}
};
...
...
paddle/cinn/poly/domain_add_unit_loop_mutator.cc
View file @
01a10755
...
...
@@ -20,7 +20,7 @@
#include <vector>
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
...
...
Prev
1
…
22
23
24
25
26
27
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment