Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
01a10755
Commit
01a10755
authored
Mar 04, 2024
by
yuguo-Jack
Browse files
2.5.2-dtk24.04
parent
63eb0da5
Changes
565
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
662 additions
and
45 deletions
+662
-45
paddle/cinn/common/cinn_value_test.cc
paddle/cinn/common/cinn_value_test.cc
+1
-1
paddle/cinn/common/context.cc
paddle/cinn/common/context.cc
+3
-3
paddle/cinn/common/context.h
paddle/cinn/common/context.h
+26
-5
paddle/cinn/common/dev_info_base.h
paddle/cinn/common/dev_info_base.h
+30
-0
paddle/cinn/common/dev_info_manager.h
paddle/cinn/common/dev_info_manager.h
+67
-0
paddle/cinn/common/equation_graph_topo_walker.h
paddle/cinn/common/equation_graph_topo_walker.h
+174
-0
paddle/cinn/common/equation_graph_topo_walker_test.cc
paddle/cinn/common/equation_graph_topo_walker_test.cc
+116
-0
paddle/cinn/common/ir_util.cc
paddle/cinn/common/ir_util.cc
+8
-9
paddle/cinn/common/macros.h
paddle/cinn/common/macros.h
+7
-7
paddle/cinn/common/make_is_reachable_from_src_predicator.h
paddle/cinn/common/make_is_reachable_from_src_predicator.h
+34
-0
paddle/cinn/common/make_subgraph_walker.h
paddle/cinn/common/make_subgraph_walker.h
+67
-0
paddle/cinn/common/nvgpu_dev_info.cc
paddle/cinn/common/nvgpu_dev_info.cc
+50
-0
paddle/cinn/common/nvgpu_dev_info.h
paddle/cinn/common/nvgpu_dev_info.h
+48
-0
paddle/cinn/common/target.cc
paddle/cinn/common/target.cc
+9
-0
paddle/cinn/common/target.h
paddle/cinn/common/target.h
+2
-2
paddle/cinn/common/topo_walker.h
paddle/cinn/common/topo_walker.h
+10
-10
paddle/cinn/frontend/computation.cc
paddle/cinn/frontend/computation.cc
+6
-4
paddle/cinn/frontend/computation.h
paddle/cinn/frontend/computation.h
+1
-2
paddle/cinn/frontend/computation_test.cc
paddle/cinn/frontend/computation_test.cc
+1
-1
paddle/cinn/frontend/decomposer/activation_test.cc
paddle/cinn/frontend/decomposer/activation_test.cc
+2
-1
No files found.
Too many changes to show.
To preserve performance only
565 of 565+
files are displayed.
Plain diff
Email patch
paddle/cinn/common/cinn_value_test.cc
View file @
01a10755
...
...
@@ -19,7 +19,7 @@
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace
cinn
{
namespace
common
{
...
...
paddle/cinn/common/context.cc
View file @
01a10755
...
...
@@ -78,7 +78,7 @@ std::string NameGenerator::New(const std::string& name_hint) {
}
// namespace common
DEFINE_bool
(
cinn_runtime_display_debug_info
,
PD_
DEFINE_bool
(
cinn_runtime_display_debug_info
,
false
,
"Whether to display debug information in runtime"
);
}
// namespace cinn
paddle/cinn/common/context.h
View file @
01a10755
...
...
@@ -14,7 +14,6 @@
#pragma once
#include <absl/types/any.h>
#include <gflags/gflags.h>
#include <isl/cpp.h>
#include <mutex>
...
...
@@ -25,10 +24,11 @@
#include "paddle/cinn/common/debug_manager.h"
#include "paddle/cinn/common/info_registry.h"
#include "paddle/cinn/common/target.h"
#include "paddle/utils/flags.h"
namespace
cinn
{
DECLARE_bool
(
cinn_runtime_display_debug_info
);
PD_
DECLARE_bool
(
cinn_runtime_display_debug_info
);
namespace
ir
{
class
Expr
;
...
...
@@ -52,6 +52,22 @@ struct NameGenerator {
mutable
std
::
mutex
mutex_
;
};
struct
PrettyNamer
{
const
std
::
string
&
GetOrNew
(
const
size_t
hash_key
,
const
std
::
string
&
name_hint
)
{
if
(
pretty_names_
.
find
(
hash_key
)
==
pretty_names_
.
end
())
{
pretty_names_
[
hash_key
]
=
name_generator_
.
New
(
name_hint
);
}
return
pretty_names_
.
at
(
hash_key
);
}
NameGenerator
&
GetNameGenerator
()
{
return
name_generator_
;
}
private:
absl
::
flat_hash_map
<
size_t
,
std
::
string
>
pretty_names_
;
NameGenerator
name_generator_
;
};
class
Context
{
public:
static
Context
&
Global
();
...
...
@@ -61,10 +77,15 @@ class Context {
* @param name_hint The prefix.
*/
std
::
string
NewName
(
const
std
::
string
&
name_hint
)
{
return
name_g
enerator
_
.
New
(
name_hint
);
return
pretty_namer_
.
GetNameG
enerator
()
.
New
(
name_hint
);
}
void
ResetNameId
()
{
name_generator_
.
ResetID
();
}
std
::
string
PrettyUniqName
(
const
size_t
hash_key
,
const
std
::
string
&
name_hint
)
{
return
pretty_namer_
.
GetOrNew
(
hash_key
,
name_hint
);
}
void
ResetNameId
()
{
pretty_namer_
.
GetNameGenerator
().
ResetID
();
}
const
std
::
vector
<
std
::
string
>&
runtime_include_dir
();
...
...
@@ -82,7 +103,7 @@ class Context {
private:
Context
()
=
default
;
NameGenerator
name_generato
r_
;
PrettyNamer
pretty_name
r_
;
std
::
vector
<
std
::
string
>
runtime_include_dir_
;
mutable
std
::
mutex
mutex_
;
...
...
paddle/cinn/common/dev_info_base.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace
cinn
{
namespace
common
{
class
DevInfoBase
{
public:
explicit
DevInfoBase
(
int
device_num
=
0
)
:
device_num_
(
device_num
)
{}
virtual
~
DevInfoBase
()
=
default
;
protected:
int
device_num_
;
};
}
// namespace common
}
// namespace cinn
paddle/cinn/common/dev_info_manager.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/cinn/common/dev_info_base.h"
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/common/nvgpu_dev_info.h"
#include "paddle/cinn/common/target.h"
namespace
cinn
{
namespace
common
{
template
<
Target
::
Arch
arch
>
struct
GetDevType
{
using
DevType
=
DevInfoBase
;
};
// Extra device should be added here
class
NVGPUDevInfo
;
template
<
>
struct
GetDevType
<
Target
::
Arch
::
NVGPU
>
{
using
DevType
=
NVGPUDevInfo
;
};
template
<
Target
::
Arch
arch
>
class
DevInfoMgr
final
{
private:
explicit
DevInfoMgr
(
int
device_num
=
0
)
:
device_num_
(
device_num
)
{
impl_
=
std
::
make_unique
<
typename
GetDevType
<
arch
>::
DevType
>
(
device_num
);
}
std
::
unique_ptr
<
DevInfoBase
>
impl_
;
int
device_num_
;
public:
static
DevInfoMgr
<
arch
>
GetDevInfo
(
int
device_num
=
0
)
{
return
DevInfoMgr
(
device_num
);
}
using
RetType
=
typename
GetDevType
<
arch
>::
DevType
;
const
RetType
*
operator
->
()
const
{
CHECK
(
!
std
::
is_void
<
RetType
>
())
<<
"Current device can't be recognized!
\n
"
;
return
dynamic_cast
<
const
RetType
*>
(
impl_
.
get
());
}
RetType
*
operator
->
()
{
CHECK
(
!
std
::
is_void
<
RetType
>
())
<<
"Current device can't be recognized!
\n
"
;
return
dynamic_cast
<
RetType
*>
(
impl_
.
get
());
}
};
}
// namespace common
}
// namespace cinn
paddle/cinn/common/equation_graph_topo_walker.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <array>
#include <functional>
#include <iostream>
#include <queue>
#include <tuple>
#include <unordered_set>
#include <vector>
#include "paddle/cinn/common/bfs_walker.h"
namespace
cinn
{
template
<
typename
VT
,
typename
FT
>
class
EquationGraphTopoWalker
final
{
public:
using
VariableVisitorT
=
std
::
function
<
void
(
VT
)
>
;
using
FunctionVisitorT
=
std
::
function
<
void
(
FT
)
>
;
using
F4VVisitor
=
std
::
function
<
void
(
VT
,
const
FunctionVisitorT
&
)
>
;
using
V4FVisitor
=
std
::
function
<
void
(
FT
,
const
VariableVisitorT
&
)
>
;
EquationGraphTopoWalker
(
const
F4VVisitor
&
NextFunctionsVisitor
,
const
V4FVisitor
&
InputVariablesVisitor
,
const
V4FVisitor
&
OutputVariablesVisitor
)
:
VisitNextFunctions
(
NextFunctionsVisitor
),
VisitInputVariables
(
InputVariablesVisitor
),
VisitOutputVariables
(
OutputVariablesVisitor
)
{}
~
EquationGraphTopoWalker
()
=
default
;
static
F4VVisitor
Merge
(
const
F4VVisitor
&
lhs
,
const
F4VVisitor
&
rhs
)
{
return
[
=
](
VT
variable
,
const
FunctionVisitorT
&
Visit
)
{
lhs
(
variable
,
Visit
);
rhs
(
variable
,
Visit
);
};
}
static
V4FVisitor
Merge
(
const
V4FVisitor
&
lhs
,
const
V4FVisitor
&
rhs
)
{
return
[
=
](
FT
function
,
const
VariableVisitorT
&
Visit
)
{
lhs
(
function
,
Visit
);
rhs
(
function
,
Visit
);
};
}
EquationGraphTopoWalker
Merge
(
const
EquationGraphTopoWalker
&
that
)
const
{
return
{
Merge
(
this
->
VisitNextFunctions
,
that
.
VisitNextFunctions
),
Merge
(
this
->
VisitInputVariables
,
that
.
VisitInputVariables
),
Merge
(
this
->
VisitOutputVariables
,
that
.
VisitOutputVariables
)};
}
void
WalkVariable
(
VT
start
,
const
VariableVisitorT
&
VariableVisitor
)
const
{
std
::
array
<
VT
,
1
>
starts
{
start
};
(
*
this
)(
starts
.
begin
(),
starts
.
end
(),
VariableVisitor
,
[
&
](
FT
)
{});
}
template
<
typename
VarIterT
>
void
WalkVariable
(
VarIterT
begin
,
VarIterT
end
,
const
VariableVisitorT
&
VariableVisitor
)
const
{
(
*
this
)(
begin
,
end
,
VariableVisitor
,
[
&
](
FT
)
{});
}
void
WalkFunction
(
VT
start
,
const
FunctionVisitorT
&
FunctionVisitor
)
const
{
std
::
array
<
VT
,
1
>
starts
{
start
};
(
*
this
)(
starts
.
begin
(),
starts
.
end
(),
[
&
](
VT
)
{},
FunctionVisitor
);
}
template
<
typename
VarIterT
>
void
WalkFunction
(
VarIterT
begin
,
VarIterT
end
,
const
FunctionVisitorT
&
FunctionVisitor
)
const
{
(
*
this
)(
begin
,
end
,
[
&
](
VT
)
{},
FunctionVisitor
);
}
void
BfsWalkFunction
(
VT
variable
,
const
FunctionVisitorT
&
FunctionVisitor
)
const
{
std
::
array
<
VT
,
1
>
array
{
variable
};
BfsWalkFunction
(
array
.
begin
(),
array
.
end
(),
FunctionVisitor
);
}
template
<
typename
VarIterT
>
void
BfsWalkFunction
(
VarIterT
begin
,
VarIterT
end
,
const
FunctionVisitorT
&
FunctionVisitor
)
const
{
using
F4FVisitor
=
std
::
function
<
void
(
FT
,
const
FunctionVisitorT
&
)
>
;
F4FVisitor
BfsVisitNextFunction
=
[
&
](
FT
f
,
const
FunctionVisitorT
&
DoEach
)
{
VisitInputVariables
(
f
,
[
&
](
VT
variable
)
{
VisitNextFunctions
(
variable
,
DoEach
);
});
VisitOutputVariables
(
f
,
[
&
](
VT
variable
)
{
VisitNextFunctions
(
variable
,
DoEach
);
});
};
std
::
vector
<
FT
>
starts
{};
for
(
VarIterT
iter
=
begin
;
iter
!=
end
;
++
iter
)
{
VisitNextFunctions
(
*
iter
,
[
&
](
FT
f
)
{
starts
.
emplace_back
(
f
);
});
}
common
::
BfsWalker
<
FT
>
bfs_walker
{
BfsVisitNextFunction
};
bfs_walker
(
starts
.
begin
(),
starts
.
end
(),
FunctionVisitor
);
}
template
<
typename
VarIterT
>
void
operator
()(
VarIterT
begin
,
VarIterT
end
,
const
VariableVisitorT
&
VariableVisitor
,
const
FunctionVisitorT
&
FunctionVisitor
)
const
{
std
::
queue
<
VT
>
variables_queue
{};
std
::
unordered_set
<
VT
>
queued_variables
{};
std
::
queue
<
FT
>
functions_queue
{};
std
::
unordered_set
<
FT
>
queued_functions
{};
const
auto
&
TryEnqueueVaraible
=
[
&
](
VT
variable
)
{
if
(
queued_variables
.
count
(
variable
)
==
0
)
{
variables_queue
.
push
(
variable
);
queued_variables
.
insert
(
variable
);
}
};
const
auto
&
TryEnqueueFunction
=
[
&
](
FT
function
)
{
if
(
queued_functions
.
count
(
function
)
==
0
)
{
functions_queue
.
push
(
function
);
queued_functions
.
insert
(
function
);
}
};
for
(
VarIterT
iter
=
begin
;
iter
!=
end
;
++
iter
)
{
TryEnqueueVaraible
(
*
iter
);
}
while
(
!
functions_queue
.
empty
()
||
!
variables_queue
.
empty
())
{
if
(
!
functions_queue
.
empty
())
{
FT
function
=
functions_queue
.
front
();
functions_queue
.
pop
();
FunctionVisitor
(
function
);
VisitOutputVariables
(
function
,
TryEnqueueVaraible
);
}
if
(
!
variables_queue
.
empty
())
{
VT
variable
=
variables_queue
.
front
();
variables_queue
.
pop
();
VariableVisitor
(
variable
);
VisitNextFunctions
(
variable
,
[
&
](
FT
function
)
{
size_t
num_unfinished_inputs
=
0
;
VisitInputVariables
(
function
,
[
&
](
VT
in_variable
)
{
num_unfinished_inputs
+=
(
queued_variables
.
count
(
in_variable
)
>
0
?
0
:
1
);
});
if
(
num_unfinished_inputs
==
0
)
{
TryEnqueueFunction
(
function
);
}
});
}
}
}
// tNext [Function] <- Variable
F4VVisitor
VisitNextFunctions
;
// tIn [Variable] <- Function
V4FVisitor
VisitInputVariables
;
// tOut [Variable] <- Function
V4FVisitor
VisitOutputVariables
;
};
}
// namespace cinn
paddle/cinn/common/equation_graph_topo_walker_test.cc
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// TODO(yifan): Add unittest here
#include "paddle/cinn/common/equation_graph_topo_walker.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
namespace
adt
{
namespace
common
{
using
VT
=
int
;
using
FT
=
std
::
string
;
/*
Graph ex:
1-> "1->10" -> 10
2-> "2->20" -> 20
*/
TEST
(
EquationGraphTopoWalker
,
simple1
)
{
auto
F4V
=
[](
VT
variable
,
const
std
::
function
<
void
(
FT
)
>&
visitor
)
{
if
(
variable
==
1
)
{
visitor
(
"1->10"
);
}
else
if
(
variable
==
2
)
{
visitor
(
"2->20"
);
}
};
auto
InV4F
=
[](
FT
function
,
const
std
::
function
<
void
(
VT
)
>&
visitor
)
{
if
(
function
==
"1->10"
)
{
visitor
(
1
);
}
else
if
(
function
==
"2->20"
)
{
visitor
(
2
);
}
};
auto
OutV4F
=
[](
FT
function
,
const
std
::
function
<
void
(
VT
)
>&
visitor
)
{
if
(
function
==
"1->10"
)
{
visitor
(
10
);
}
else
if
(
function
==
"2->20"
)
{
visitor
(
20
);
}
};
cinn
::
EquationGraphTopoWalker
<
VT
,
FT
>
walker
(
F4V
,
InV4F
,
OutV4F
);
std
::
vector
<
FT
>
outputs
;
std
::
function
<
void
(
FT
)
>
FunctionVisitor
=
[
&
](
FT
function
)
{
outputs
.
push_back
(
function
);
};
walker
.
WalkFunction
(
1
,
FunctionVisitor
);
std
::
vector
<
FT
>
expected
{
"1->10"
};
EXPECT_TRUE
((
outputs
==
expected
));
}
/*
Graph ex:
1 -> "1->10, 1->11" -> 10
-> 11
2 -> "2->20" -> 20
3 -> "3->30, 3->31" -> 30
-> 31
*/
TEST
(
EquationGraphTopoWalker
,
simple2
)
{
auto
F4V
=
[](
VT
variable
,
const
std
::
function
<
void
(
FT
)
>&
visitor
)
{
if
(
variable
==
1
)
{
visitor
(
"1->10, 1->11"
);
}
else
if
(
variable
==
2
)
{
visitor
(
"2->20"
);
}
else
if
(
variable
==
3
)
{
visitor
(
"3->30, 3->31"
);
}
};
auto
InV4F
=
[](
FT
function
,
const
std
::
function
<
void
(
VT
)
>&
visitor
)
{
if
(
function
==
"1->10, 1->11"
)
{
visitor
(
1
);
}
else
if
(
function
==
"2->20"
)
{
visitor
(
2
);
}
else
if
(
function
==
"3->30, 3->31"
)
{
visitor
(
3
);
}
};
auto
OutV4F
=
[](
FT
function
,
const
std
::
function
<
void
(
VT
)
>&
visitor
)
{
if
(
function
==
"1->10, 1->11"
)
{
visitor
(
10
);
visitor
(
11
);
}
else
if
(
function
==
"2->20"
)
{
visitor
(
20
);
}
else
if
(
function
==
"3->30, 3->31"
)
{
visitor
(
30
);
visitor
(
31
);
}
};
cinn
::
EquationGraphTopoWalker
<
VT
,
FT
>
walker
(
F4V
,
InV4F
,
OutV4F
);
std
::
vector
<
VT
>
outputs
;
std
::
function
<
void
(
VT
)
>
VariableVisitor
=
[
&
](
VT
variable
)
{
outputs
.
push_back
(
variable
);
};
walker
.
WalkVariable
(
1
,
VariableVisitor
);
std
::
vector
<
VT
>
expected
{
1
,
10
,
11
};
EXPECT_TRUE
((
outputs
==
expected
));
}
}
// namespace common
}
// namespace adt
paddle/cinn/common/ir_util.cc
View file @
01a10755
...
...
@@ -18,10 +18,9 @@
#include <unordered_set>
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/cast_simplify.h"
namespace
cinn
{
namespace
common
{
...
...
@@ -147,7 +146,7 @@ Expr IndiceToAbsOffset(const std::vector<Expr> &shape,
for
(
int
i
=
0
;
i
<
shape
.
size
();
i
++
)
{
CHECK_EQ
(
shape
[
i
].
type
(),
Int
(
32
));
Expr
indice_prod
=
indices
[
i
];
optim
::
Cast
Simplify
(
&
indice_prod
);
optim
::
Simplify
Cast
(
&
indice_prod
);
for
(
int
j
=
i
+
1
;
j
<
shape
.
size
();
j
++
)
{
indice_prod
=
RampRelatedMul
(
indice_prod
,
shape
[
j
]);
}
...
...
@@ -250,8 +249,8 @@ Expr or_all(const std::vector<Expr> &conds) {
}
void
CheckTensorUniqueInExpr
(
Expr
expr
)
{
auto
tensor_uniq
=
ir
::
CollectIRNodes
(
expr
,
[](
const
Expr
*
x
)
{
return
x
->
as_tensor
();
});
auto
tensor_uniq
=
ir
::
ir_utils
::
CollectIRNodes
(
expr
,
[](
const
Expr
*
x
)
{
return
x
->
as_tensor
();
});
absl
::
flat_hash_map
<
std
::
string
,
const
ir
::
_Tensor_
*>
tensor_names
;
for
(
auto
&
t
:
tensor_uniq
)
{
auto
*
tp
=
t
.
as_tensor
();
...
...
@@ -270,9 +269,9 @@ void CheckBufferUniqueInExpr(Expr expr) {
// the buffers exists in tensor and lowered functions.
CheckTensorUniqueInExpr
(
expr
);
auto
tensors
=
ir
::
CollectIRNodes
(
expr
,
[](
const
Expr
*
x
)
{
return
x
->
as_tensor
();
});
auto
funcs
=
ir
::
CollectIRNodes
(
auto
tensors
=
ir
::
ir_utils
::
CollectIRNodes
(
expr
,
[](
const
Expr
*
x
)
{
return
x
->
as_tensor
();
});
auto
funcs
=
ir
::
ir_utils
::
CollectIRNodes
(
expr
,
[](
const
Expr
*
x
)
{
return
x
->
as_lowered_func
();
});
absl
::
flat_hash_map
<
std
::
string
,
const
ir
::
_Buffer_
*>
buffer_name
;
...
...
paddle/cinn/common/macros.h
View file @
01a10755
...
...
@@ -69,8 +69,8 @@
#define USE_FUSION_PASS(pass_name) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_fusion_pass_##pass_name, \
"USE_
OP_ITSELF
must be called in global namespace"); \
extern int TouchFusionPassRegistrar_##pass_name(); \
[[maybe_unused]] static int __use_fusion_pass_##pass_name##_ = \
TouchFusionPassRegistrar_##pass_name()
__use_
cinn_
fusion_pass_##pass_name, \
"USE_
FUSION_PASS
must be called in global namespace");
\
extern int Touch
Cinn
FusionPassRegistrar_##pass_name();
\
[[maybe_unused]] static int __use_
cinn_
fusion_pass_##pass_name##_ = \
Touch
Cinn
FusionPassRegistrar_##pass_name()
paddle/cinn/common/make_is_reachable_from_src_predicator.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include "paddle/cinn/common/topo_walker.h"
namespace
cinn
::
common
{
template
<
typename
NodeT
,
typename
IterT
>
std
::
function
<
bool
(
NodeT
)
>
MakeIsReachableFromSrcPredicator
(
const
TopoWalker
<
NodeT
>&
walker
,
IterT
src_begin
,
IterT
src_end
)
{
auto
nodes
=
std
::
make_shared
<
std
::
unordered_set
<
NodeT
>>
();
nodes
->
insert
(
src_begin
,
src_end
);
walker
(
src_begin
,
src_end
,
[
&
](
NodeT
node
)
{
nodes
->
insert
(
node
);
});
return
[
nodes
](
NodeT
node
)
{
return
nodes
->
count
(
node
)
>
0
;
};
}
}
// namespace cinn::common
paddle/cinn/common/make_subgraph_walker.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#pragma once
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include "paddle/cinn/common/make_is_reachable_from_src_predicator.h"
#include "paddle/cinn/common/topo_walker.h"
namespace
cinn
::
common
{
template
<
typename
NodeT
,
typename
IterT
>
common
::
TopoWalker
<
NodeT
>
MakeSubgraphWalker
(
const
common
::
TopoWalker
<
NodeT
>&
walker
,
IterT
src_begin
,
IterT
src_end
,
IterT
sink_begin
,
IterT
sink_end
)
{
common
::
TopoWalker
<
NodeT
>
reversed_walker
(
walker
.
VisitNextNodes
,
walker
.
VisitPrevNodes
);
auto
ReachableToOneSrc
=
common
::
MakeIsReachableFromSrcPredicator
<
NodeT
,
IterT
>
(
walker
,
src_begin
,
src_end
);
auto
ReachableToOneSink
=
common
::
MakeIsReachableFromSrcPredicator
<
NodeT
,
IterT
>
(
reversed_walker
,
sink_begin
,
sink_end
);
auto
VisitPrevNodes
=
[
ReachableToOneSrc
,
ReachableToOneSink
,
walker
](
NodeT
node
,
const
std
::
function
<
void
(
NodeT
)
>&
Visitor
)
{
walker
.
VisitPrevNodes
(
node
,
[
&
](
NodeT
in_node
)
{
if
(
ReachableToOneSrc
(
in_node
)
&&
ReachableToOneSink
(
in_node
))
{
Visitor
(
in_node
);
}
});
};
auto
VisitNextNodes
=
[
ReachableToOneSrc
,
ReachableToOneSink
,
walker
](
NodeT
node
,
const
std
::
function
<
void
(
NodeT
)
>&
Visitor
)
{
walker
.
VisitNextNodes
(
node
,
[
&
](
NodeT
out_node
)
{
if
(
ReachableToOneSrc
(
out_node
)
&&
ReachableToOneSink
(
out_node
))
{
Visitor
(
out_node
);
}
});
};
return
common
::
TopoWalker
<
NodeT
>
(
VisitPrevNodes
,
VisitNextNodes
);
}
}
// namespace cinn::common
paddle/cinn/common/nvgpu_dev_info.cc
0 → 100644
View file @
01a10755
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef CINN_WITH_CUDA
#include "paddle/cinn/common/nvgpu_dev_info.h"
namespace
cinn
{
namespace
common
{
std
::
array
<
int
,
3
>
NVGPUDevInfo
::
GetMaxGridDims
()
const
{
std
::
array
<
int
,
3
>
ret
;
ret
[
0
]
=
prop_
.
maxGridSize
[
0
];
ret
[
1
]
=
prop_
.
maxGridSize
[
1
];
ret
[
2
]
=
prop_
.
maxGridSize
[
2
];
return
ret
;
}
std
::
array
<
int
,
3
>
NVGPUDevInfo
::
GetMaxBlockDims
()
const
{
std
::
array
<
int
,
3
>
ret
;
ret
[
0
]
=
prop_
.
maxThreadsDim
[
0
];
ret
[
1
]
=
prop_
.
maxThreadsDim
[
1
];
ret
[
2
]
=
prop_
.
maxThreadsDim
[
2
];
return
ret
;
}
int
NVGPUDevInfo
::
GetMultiProcessorCount
()
const
{
return
prop_
.
multiProcessorCount
;
}
int
NVGPUDevInfo
::
GetMaxThreadsPerMultiProcessor
()
const
{
return
prop_
.
maxThreadsPerMultiProcessor
;
}
int
NVGPUDevInfo
::
GetMaxThreadsPerBlock
()
const
{
return
prop_
.
maxThreadsPerBlock
;
}
}
// namespace common
}
// namespace cinn
#endif
paddle/cinn/common/nvgpu_dev_info.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef CINN_WITH_CUDA
#include <ostream>
#include <string>
#include <vector>
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/common/dev_info_base.h"
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/common/target.h"
namespace
cinn
{
namespace
common
{
class
NVGPUDevInfo
:
public
DevInfoBase
{
public:
explicit
NVGPUDevInfo
(
int
device_num
=
0
)
:
DevInfoBase
(
device_num
)
{
CUDA_CALL
(
cudaGetDeviceProperties
(
&
prop_
,
device_num
));
}
std
::
array
<
int
,
3
>
GetMaxGridDims
()
const
;
std
::
array
<
int
,
3
>
GetMaxBlockDims
()
const
;
int
GetMultiProcessorCount
()
const
;
int
GetMaxThreadsPerMultiProcessor
()
const
;
int
GetMaxThreadsPerBlock
()
const
;
private:
cudaDeviceProp
prop_
;
};
}
// namespace common
}
// namespace cinn
#endif
paddle/cinn/common/target.cc
View file @
01a10755
...
...
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef CINN_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <driver_types.h>
#endif
...
...
@@ -20,12 +21,20 @@
#include <sstream>
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/runtime/cinn_runtime.h"
namespace
cinn
{
namespace
common
{
Target
::
Target
(
OS
o
,
Arch
a
,
Bit
b
,
const
std
::
vector
<
Feature
>
&
features
,
const
std
::
vector
<
Lib
>
&
libs
)
:
os
(
o
),
arch
(
a
),
bits
(
b
),
features
(
features
),
libs
(
libs
)
{}
bool
Target
::
operator
==
(
const
Target
&
other
)
const
{
return
os
==
other
.
os
&&
//
arch
==
other
.
arch
&&
//
...
...
paddle/cinn/common/target.h
View file @
01a10755
...
...
@@ -14,6 +14,7 @@
#pragma once
#include <array>
#include <ostream>
#include <string>
#include <vector>
...
...
@@ -71,8 +72,7 @@ struct Target {
Arch
a
=
Arch
::
Unk
,
Bit
b
=
Bit
::
Unk
,
const
std
::
vector
<
Feature
>&
features
=
{},
const
std
::
vector
<
Lib
>&
libs
=
{})
:
os
(
o
),
arch
(
a
),
bits
(
b
),
features
(
features
),
libs
(
libs
)
{}
const
std
::
vector
<
Lib
>&
libs
=
{});
bool
defined
()
const
{
return
os
!=
OS
::
Unk
&&
arch
!=
Arch
::
Unk
&&
bits
!=
Bit
::
Unk
;
...
...
paddle/cinn/common/topo_walker.h
View file @
01a10755
...
...
@@ -26,16 +26,17 @@ namespace common {
template
<
typename
NodeType
>
class
TopoWalker
final
{
public:
TopoWalker
(
const
TopoWalker
&
)
=
de
lete
;
TopoWalker
(
TopoWalker
&&
)
=
de
lete
;
TopoWalker
(
const
TopoWalker
&
)
=
de
fault
;
TopoWalker
(
TopoWalker
&&
)
=
de
fault
;
using
NodeHandlerType
=
std
::
function
<
void
(
NodeType
)
>
;
using
NodesVisitorType
=
std
::
function
<
void
(
NodeType
,
const
NodeHandlerType
&
)
>
;
TopoWalker
(
const
NodesVisitorType
&
VisitPrevNodes
,
const
NodesVisitorType
&
VisitNextNodes
)
:
VisitPrevNodes_
(
VisitPrevNodes
),
VisitNextNodes_
(
VisitNextNodes
)
{}
TopoWalker
(
const
NodesVisitorType
&
VisitPrevNodesValue
,
const
NodesVisitorType
&
VisitNextNodesValue
)
:
VisitPrevNodes
(
VisitPrevNodesValue
),
VisitNextNodes
(
VisitNextNodesValue
)
{}
void
operator
()(
NodeType
node
,
const
NodeHandlerType
&
NodeHandler
)
const
{
std
::
array
<
NodeType
,
1
>
nodes
{
node
};
...
...
@@ -61,9 +62,9 @@ class TopoWalker final {
NodeType
node
=
node_queue
.
front
();
node_queue
.
pop
();
NodeHandler
(
node
);
VisitNextNodes
_
(
node
,
[
&
](
NodeType
node
)
{
VisitNextNodes
(
node
,
[
&
](
NodeType
node
)
{
size_t
num_unfinished_inputs
=
0
;
VisitPrevNodes
_
(
node
,
[
&
](
NodeType
in_node
)
{
VisitPrevNodes
(
node
,
[
&
](
NodeType
in_node
)
{
num_unfinished_inputs
+=
(
queued_nodes
.
count
(
in_node
)
>
0
?
0
:
1
);
});
if
(
num_unfinished_inputs
==
0
)
{
...
...
@@ -73,9 +74,8 @@ class TopoWalker final {
}
}
private:
NodesVisitorType
VisitPrevNodes_
;
NodesVisitorType
VisitNextNodes_
;
NodesVisitorType
VisitPrevNodes
;
NodesVisitorType
VisitNextNodes
;
};
}
// namespace common
...
...
paddle/cinn/frontend/computation.cc
View file @
01a10755
...
...
@@ -72,16 +72,18 @@ std::shared_ptr<ComputationContext> CompileProgram(
}
ctx
->
scope
=
hlir
::
framework
::
BuildScope
(
target
,
ctx
->
graph
,
scope
);
ctx
->
graph_compiler
.
reset
(
new
hlir
::
framework
::
GraphCompiler
(
target
,
ctx
->
scope
,
ctx
->
graph
));
std
::
unordered_set
<
std
::
string
>
fetch_var_ids
;
for
(
auto
&
out
:
outputs
)
{
fetch_var_ids
.
insert
(
out
->
id
);
}
ctx
->
program
=
ctx
->
graph_compiler
->
Build
(
options
,
std
::
move
(
fetch_var_ids
))
.
runtime_program
;
ctx
->
compile_options
.
graph
=
ctx
->
graph
;
ctx
->
compile_options
.
scope
=
ctx
->
scope
;
ctx
->
compile_options
.
fetch_var_ids
=
fetch_var_ids
;
ctx
->
graph_compiler
.
reset
(
new
hlir
::
framework
::
GraphCompiler
(
ctx
->
compile_options
));
ctx
->
program
=
ctx
->
graph_compiler
->
Build
();
if
(
ctx
->
compile_options
.
do_prerun
)
{
ctx
->
program
->
PreRun
();
}
...
...
paddle/cinn/frontend/computation.h
View file @
01a10755
...
...
@@ -27,8 +27,7 @@ struct ComputationContext;
class
CinnComputation
{
public:
struct
CompileOptions
:
public
hlir
::
framework
::
GraphCompiler
::
CompileOptions
{
struct
CompileOptions
:
public
hlir
::
framework
::
CompilationContext
{
bool
use_decomposer
=
false
;
bool
do_prerun
=
true
;
bool
use_default_passes
=
true
;
...
...
paddle/cinn/frontend/computation_test.cc
View file @
01a10755
...
...
@@ -23,7 +23,7 @@
#include "paddle/cinn/frontend/pass/use_program_pass.h"
#include "paddle/cinn/frontend/program_pass.h"
DEFINE_string
(
model_dir
,
""
,
""
);
PD_
DEFINE_string
(
model_dir
,
""
,
""
);
namespace
cinn
{
namespace
frontend
{
...
...
paddle/cinn/frontend/decomposer/activation_test.cc
View file @
01a10755
...
...
@@ -86,7 +86,8 @@ TEST(Decomposer, softmax_decomposer) {
hlir
::
framework
::
ApplyPass
(
graph
.
get
(),
"FusionMergePass"
);
auto
scope
=
BuildScope
(
target
,
graph
);
hlir
::
framework
::
GraphCompiler
gc
(
target
,
scope
,
graph
);
hlir
::
framework
::
CompilationContext
context
(
graph
,
scope
,
target
);
hlir
::
framework
::
GraphCompiler
gc
(
context
);
auto
run_program
=
gc
.
Build
();
std
::
vector
<
float
>
x
(
n
*
c
*
h
*
w
);
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
29
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment