Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
01a10755
Commit
01a10755
authored
Mar 04, 2024
by
yuguo-Jack
Browse files
2.5.2-dtk24.04
parent
63eb0da5
Changes
565
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
510 additions
and
40 deletions
+510
-40
paddle/cinn/hlir/framework/pir/utils.h
paddle/cinn/hlir/framework/pir/utils.h
+83
-0
paddle/cinn/hlir/framework/pir_compiler.cc
paddle/cinn/hlir/framework/pir_compiler.cc
+220
-0
paddle/cinn/hlir/framework/pir_compiler.h
paddle/cinn/hlir/framework/pir_compiler.h
+99
-0
paddle/cinn/hlir/framework/visualize_helper.cc
paddle/cinn/hlir/framework/visualize_helper.cc
+2
-2
paddle/cinn/hlir/op/broadcast.cc
paddle/cinn/hlir/op/broadcast.cc
+64
-15
paddle/cinn/hlir/op/contrib/argmax.cc
paddle/cinn/hlir/op/contrib/argmax.cc
+19
-0
paddle/cinn/hlir/op/contrib/argmin.cc
paddle/cinn/hlir/op/contrib/argmin.cc
+20
-0
paddle/cinn/hlir/op/contrib/bitcast_convert.cc
paddle/cinn/hlir/op/contrib/bitcast_convert.cc
+0
-2
paddle/cinn/hlir/op/contrib/cholesky.cc
paddle/cinn/hlir/op/contrib/cholesky.cc
+0
-2
paddle/cinn/hlir/op/contrib/gather_nd.cc
paddle/cinn/hlir/op/contrib/gather_nd.cc
+0
-2
paddle/cinn/hlir/op/contrib/gaussian_random.cc
paddle/cinn/hlir/op/contrib/gaussian_random.cc
+0
-2
paddle/cinn/hlir/op/contrib/logical_right_shift.cc
paddle/cinn/hlir/op/contrib/logical_right_shift.cc
+1
-1
paddle/cinn/hlir/op/contrib/lookup_table.cc
paddle/cinn/hlir/op/contrib/lookup_table.cc
+1
-1
paddle/cinn/hlir/op/contrib/one_hot.cc
paddle/cinn/hlir/op/contrib/one_hot.cc
+0
-2
paddle/cinn/hlir/op/contrib/randint.cc
paddle/cinn/hlir/op/contrib/randint.cc
+0
-2
paddle/cinn/hlir/op/contrib/reciprocal.cc
paddle/cinn/hlir/op/contrib/reciprocal.cc
+1
-1
paddle/cinn/hlir/op/contrib/repeat.cc
paddle/cinn/hlir/op/contrib/repeat.cc
+0
-2
paddle/cinn/hlir/op/contrib/resize.cc
paddle/cinn/hlir/op/contrib/resize.cc
+0
-2
paddle/cinn/hlir/op/contrib/sort.cc
paddle/cinn/hlir/op/contrib/sort.cc
+0
-2
paddle/cinn/hlir/op/contrib/uniform_random.cc
paddle/cinn/hlir/op/contrib/uniform_random.cc
+0
-2
No files found.
Too many changes to show.
To preserve performance only
565 of 565+
files are displayed.
Plain diff
Email patch
paddle/cinn/hlir/framework/pir/utils.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/cinn/common/context.h"
#include "paddle/cinn/common/type.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/utils/type_defs.h"
#include "paddle/pir/core/operation.h"
namespace
cinn
{
namespace
hlir
{
namespace
framework
{
namespace
pir
{
// Compiled-kernel handle for one CUDA group: the JIT'd entry point plus
// the launch dimensions recorded at lowering time.
struct CUDAJITInfo {
  // Raw pointer to the JIT-compiled kernel entry.
  void* fn_ptr;
  // CUDA block dimensions to launch the kernel with.
  std::vector<int> block_dims;
  // CUDA grid dimensions to launch the kernel with.
  std::vector<int> grid_dims;
  // Type-erased backends::Compiler that owns fn_ptr's code
  // (presumably kept alive for the kernel's lifetime — see
  // PirCompiler::BuildCUDAJITInfo, which release()s it into this field).
  void* compiler;
};
// Helpers for bridging Paddle PIR entities (ops, values, attributes, types)
// into their CINN counterparts.
struct CompatibleInfo {
  // Prefix used when synthesizing CINN variable names from PIR values.
  // NOTE: must be `const char*` — binding a string literal to a plain
  // `char*` is ill-formed since C++11.
  static constexpr const char* kNamePrefix = "var";
  // TODO(Aurelius): Need add name mapping logic in REGISTER_CINN_OP
  // macros or attempt to unify Op name with Paddle and CINN.
  static const std::unordered_map<std::string, std::string> OP_NAMES;

  // NOTE(Aurelius): Some ops in CINN register different
  // name between OpMapper and Compute/Schedule, such as
  // 'subtract': 1. OpMapper: 'elementwise_sub'; 2. Compute/Schedule:
  // 'subtract'.
  static const std::unordered_set<std::string> CINN_WHITE_OPS;

  // Returns true if `op` can be compiled by CINN.
  static bool IsSupportCinn(const ::pir::Operation& op);

  // CINN-side operator name for a PIR operation.
  static std::string OpName(const ::pir::Operation& op);

  // Deterministic variable name for a PIR value (kNamePrefix-based).
  static std::string ValueName(const ::pir::Value& value);

  // Name of the lowered function generated for `op`.
  static std::string OpFuncName(const ::pir::Operation& op);

  // Combined name for a group of operations.
  static std::string GroupOpsName(const std::vector<::pir::Operation*>& ops);

  // Input variable names of `op`; duplicates are dropped unless
  // `allow_duplicate` is set.
  static std::vector<std::string> InputNames(const ::pir::Operation& op,
                                             bool allow_duplicate = false);

  static std::vector<std::string> OutputNames(::pir::Operation& op);  // NOLINT

  // Operand sources of `op` that carry real data.
  static std::vector<::pir::Value> RealOperandSources(
      const ::pir::Operation& op);

  // PIR attribute -> CINN attribute conversions.
  static utils::Attribute ConvertAttribute(const ::pir::Attribute& src_attr);
  static utils::AttributeMap ConvertAttributes(const ::pir::Operation& op);

  // PIR type -> CINN type conversion.
  static common::Type ConvertIRType(::pir::Type type);

  // Static shape of a PIR value as a vector of ints.
  static std::vector<int> ValueShape(const ::pir::Value& value);

  // Product of all dimensions in `shape`.
  static int ShapeProduct(const std::vector<int>& shape);

  // CINN fusion-pattern kind of `op`.
  static OpPatternKind OpKind(const ::pir::Operation& op);
};
}
// namespace pir
}
// namespace framework
}
// namespace hlir
}
// namespace cinn
paddle/cinn/hlir/framework/pir_compiler.cc
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/pir_compiler.h"
#include <absl/types/variant.h>
#include "paddle/cinn/hlir/framework/pir/compilation_task.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/cinn/utils/multi_threading.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/pir/core/builtin_type.h"
PD_DECLARE_bool
(
cinn_bucket_compile
);
namespace
cinn
{
namespace
hlir
{
namespace
framework
{
// TODO(Aurelius84): Clear useless Build Interface.
// Builds an executable Program for the whole PIR program by wrapping every
// op into a single-op group and delegating to Build(groups).
std::unique_ptr<Program> PirCompiler::Build() {
  m_builder_.Clear();
  // NOTE(Aurelius84): Currently only support each op for one group
  std::vector<pir::GroupPtr> groups;
  for (auto& op : *program_.block()) {
    std::vector<::pir::Operation*> ops = {&op};
    auto group = std::make_shared<pir::Group>(ops);
    group->output_ops.insert(&op);
    groups.push_back(group);
  }
  VLOG(4) << "Groups size: " << groups.size();
  // No std::move here: moving the returned prvalue would only block copy
  // elision (clang-tidy performance-move-const-arg).
  return Build(groups);
}
// Lowers and JIT-compiles every group, returning one CUDAJITInfo per group
// (kernel entry pointer plus recorded launch dimensions).
std::vector<pir::CUDAJITInfo> PirCompiler::BuildCUDAJITInfo(
    const std::vector<pir::GroupPtr>& groups) {
  std::vector<pir::CUDAJITInfo> vec_res;

  auto op_lowerer = CreateOpLowerer<pir::GroupPtr>(target_);

  std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
  lowered_funcs.reserve(groups.size());
  for (size_t i = 0; i < groups.size(); ++i) {
    lowered_funcs.emplace_back(op_lowerer.Lower(groups[i]));
  }

  // Register every lowered function (and its buffer arguments) with the
  // module builder and scope.
  for (auto&& lowered_func : lowered_funcs) {
    ProcessFunction(lowered_func);
  }

  compiler_ = backends::Compiler::Create(target_);
  auto build_module = m_builder_.Build();
  compiler_->Build(build_module, "");

  // NOTE(review): the built instructions are discarded; the call is kept
  // because it looks up every kernel and CHECKs the symbols exist —
  // confirm whether it can be removed.
  auto instructions = BuildInstructions(groups);

  auto fn_ptrs = compiler_->GetFnPtr();

  // Ownership of the compiler is transferred into each CUDAJITInfo as a
  // type-erased pointer; it is never deleted here.
  auto* compiler_ptr = compiler_.release();
  for (size_t idx = 0; idx < groups.size(); ++idx) {
    pir::CUDAJITInfo jit_info;
    jit_info.fn_ptr = fn_ptrs[idx];
    jit_info.compiler = reinterpret_cast<void*>(compiler_ptr);

    // Launch dims were recorded on the first lowered function of the group.
    lowered_funcs[idx][0]->cuda_axis_info.CopyBlockDimsTo(
        &(jit_info.block_dims));
    lowered_funcs[idx][0]->cuda_axis_info.CopyGridDimsTo(
        &(jit_info.grid_dims));

    vec_res.push_back(jit_info);
  }
  return vec_res;
}
// Compiles the given groups into runnable Instructions and wraps them in a
// Program. Two paths: the bucket-compile path runs per-group compilation
// tasks in parallel; the legacy path lowers everything and compiles one
// module.
std::unique_ptr<Program> PirCompiler::Build(
    const std::vector<pir::GroupPtr>& groups) {
  std::vector<std::unique_ptr<Instruction>> instructions(groups.size());
  if (FLAGS_cinn_bucket_compile) {
    // NOTE(review): contexts are appended, so indices below are only valid
    // if Build is called once per compiler — confirm against callers.
    for (size_t i = 0; i < groups.size(); ++i) {
      group_compilation_contexts_.emplace_back(target_, groups[i], scope_);
    }
    auto worker_fn = [&](int index) {
      CompilationTask task(&group_compilation_contexts_[index]);
      task();
      instructions[index] = task.BuildInstruction();
    };
    // -1: let the dispatcher pick the worker count.
    utils::parallel_run(
        worker_fn, utils::SequenceDispatcher(0, groups.size()), -1);
  } else {
    auto op_lowerer = CreateOpLowerer<pir::GroupPtr>(target_);

    std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
    lowered_funcs.reserve(groups.size());
    for (size_t i = 0; i < groups.size(); ++i) {
      lowered_funcs.emplace_back(op_lowerer.Lower(groups[i]));
    }

    for (auto&& lowered_func : lowered_funcs) {
      ProcessFunction(lowered_func);
    }

    compiler_ = backends::Compiler::Create(target_);
    auto build_module = m_builder_.Build();
    compiler_->Build(build_module, "");

    instructions = BuildInstructions(groups);
  }

  // TODO(Aurelius84): Instantiate all tensors on compile-time, which is
  // controlled by 'options.with_instantiate_variables' in GraphCompiler.
  // Moreover, it's better to implement InsertBufferHandlers() logic
  // to automatically insert Malloc and Free instructions.
  for (auto& name : scope_->var_names()) {
    std::string var_name({name.data(), name.size()});
    VLOG(4) << "Instantiate " << var_name << " on compile-time";
    auto* var = scope_->Var<Tensor>(var_name);
    auto& tensor = absl::get<Tensor>(*var);
    tensor->mutable_data(target_, tensor->type());
  }
  return std::make_unique<Program>(scope_, std::move(instructions));
}
// Registers each lowered function with the module builder and ensures every
// buffer argument has a matching Tensor variable in the scope.
void PirCompiler::ProcessFunction(
    const std::vector<ir::LoweredFunc>& lowered_funcs) {
  for (auto&& func : lowered_funcs) {
    for (auto&& arg : func->args) {
      std::string arg_name = arg.name();
      // Buffer arguments are prefixed with '_' at lowering time; strip it
      // to recover the scope variable name. Guard against an empty name
      // before indexing (the original read arg_name[0] unconditionally).
      if (!arg_name.empty() && arg_name[0] == '_') {
        arg_name = arg_name.substr(1);
      }
      auto* var = scope_->FindVar(arg_name);
      // For argument buffer not in scope, create it.
      if (!var && arg.is_buffer()) {
        auto* new_var = scope_->Var<Tensor>(arg_name);
        auto& tensor = absl::get<Tensor>(*new_var);
        std::vector<Shape::dim_t> shape;
        shape.reserve(arg.buffer_arg()->shape.size());
        for (auto& shape_dim : arg.buffer_arg()->shape) {
          // Dynamic dims are not supported here: the tensor must have a
          // concrete compile-time shape.
          CHECK(shape_dim.is_constant());
          shape.push_back(static_cast<int>(shape_dim.get_constant()));
        }
        tensor->Resize(Shape{shape});
        tensor->set_type(arg.buffer_arg()->dtype);
      }
    }
    m_builder_.AddFunction(func);
  }
}
// Creates one runtime Instruction per group, binding each to the JIT'd
// kernel looked up by the group's function name.
std::vector<std::unique_ptr<Instruction>> PirCompiler::BuildInstructions(
    const std::vector<pir::GroupPtr>& groups) {
  std::vector<std::unique_ptr<Instruction>> instructions;
  instructions.reserve(groups.size());
  for (size_t idx = 0; idx < groups.size(); ++idx) {
    auto fn_name = groups[idx]->FuncName();
    // make_unique instead of raw new (no raw owning pointers).
    auto instr = std::make_unique<Instruction>(target_,
                                               scope_.get(),
                                               groups[idx]->input_names,
                                               groups[idx]->output_names,
                                               fn_name);
    VLOG(4) << "Lookup kernel name: " << fn_name;
    auto* fn_ptr = compiler_->Lookup(fn_name);
    CHECK(fn_ptr);
    instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr), fn_name);
    // As some instruction like reduce, will generate more than one kernel.
    // So try to find the rest kernel, if it exists.
    // SetSubKernels(instr.get(), fn_name);
    instr->Finalize();
    instructions.push_back(std::move(instr));
  }
  return instructions;
}
// Builds a Scope with one Tensor variable for every distinct value that any
// op in `program` reads or produces. `target` is currently unused here;
// tensors are sized/typed from the value's DenseTensorType.
std::shared_ptr<Scope> BuildScope(const Target& target,
                                  const ::pir::Program& program) {
  std::unordered_set<::pir::Value> visited;
  auto scope = std::make_shared<Scope>();

  auto create_var = [&](::pir::Value value) {
    // Skip null values and values without a type.
    if (!(value) || !(value.type())) {
      return;
    }
    // Each value gets exactly one scope variable.
    if (visited.count(value) > 0) return;
    visited.emplace(value);

    std::string name = pir::CompatibleInfo::ValueName(value);
    // NOTE(review): dyn_cast result is used unchecked — assumes every typed
    // value is a DenseTensorType; confirm non-dense types cannot reach here.
    auto type_info =
        value.type().dyn_cast<paddle::dialect::DenseTensorType>();
    auto* var = scope->Var<Tensor>(name);
    auto& tensor = absl::get<Tensor>(*var);

    std::vector<Shape::dim_t> shape;
    for (auto i = 0; i < type_info.dims().size(); ++i) {
      shape.push_back(Shape::dim_t(type_info.dims()[i]));
    }
    tensor->Resize(Shape{shape});
    tensor->set_type(pir::CompatibleInfo::ConvertIRType(type_info.dtype()));
  };

  for (auto& op : *program.block()) {
    // Inputs first, then results, so operands defined outside the block
    // still get variables. (typo fix: was `oprand`.)
    for (auto operand : op.operands()) {
      create_var(operand.source());
    }
    for (auto result : op.results()) {
      create_var(result);
    }
  }
  return scope;
}
}
// namespace framework
}
// namespace hlir
}
// namespace cinn
paddle/cinn/hlir/framework/pir_compiler.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include "paddle/cinn/common/macros.h"
#include "paddle/pir/core/program.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/op_lowering.h"
#include "paddle/cinn/hlir/framework/pir/compilation_task.h"
namespace
cinn
{
namespace
hlir
{
namespace
framework
{
// TODO(Aurelius84): Need abstract this logic to implement Proxy for
// the co-existence with GraphCompiler.
//
// Compiles a PIR program (or a list of op groups) into a runnable Program,
// owning the module builder, backend compiler, and variable scope involved.
class PirCompiler final {
 public:
  PirCompiler(const ::pir::Program& prog,
              const Target& target,
              const std::shared_ptr<Scope>& scope)
      : program_(prog),
        m_builder_("Pir", target),
        target_(target),
        scope_(scope) {}

  // Compiles the whole program, one group per op.
  std::unique_ptr<Program> Build();

  // JIT-compiles the groups and returns kernel handles + launch dims.
  std::vector<pir::CUDAJITInfo> BuildCUDAJITInfo(
      const std::vector<pir::GroupPtr>& groups);

  // Compiles exactly the given groups.
  std::unique_ptr<Program> Build(const std::vector<pir::GroupPtr>& groups);

 private:
  CINN_DISALLOW_COPY_AND_ASSIGN(PirCompiler);

  std::vector<ir::LoweredFunc> GetOpFunc(const ::pir::Operation& op, int idx);

  // Adds lowered functions to m_builder_ and creates missing buffer vars.
  void ProcessFunction(const std::vector<ir::LoweredFunc>& lowered_funcs);

  // One runtime Instruction per group, bound to its compiled kernel.
  std::vector<std::unique_ptr<Instruction>> BuildInstructions(
      const std::vector<pir::GroupPtr>& groups);

  const ::pir::Program& program_;
  ir::Module::Builder m_builder_;
  std::unique_ptr<backends::Compiler> compiler_{nullptr};
  Target target_;
  std::shared_ptr<Scope> scope_;
  std::unordered_map<std::string, std::string> func_names_;
  std::vector<GroupCompilationContext> group_compilation_contexts_;
};
std
::
shared_ptr
<
Scope
>
BuildScope
(
const
Target
&
,
const
::
pir
::
Program
&
);
// Process-wide registry that keeps every created PirCompiler alive
// (JIT'd kernels reference compiler-owned state, so compilers must not be
// destroyed while their kernels may still run).
class PirCompilerManager {
 public:
  // Meyers singleton accessor.
  static PirCompilerManager& Instance() {
    static PirCompilerManager instance;
    return instance;
  }

  // Constructs a PirCompiler and registers it with the singleton so it
  // outlives the caller's handle.
  static std::shared_ptr<PirCompiler> Create(
      const ::pir::Program& prog,
      const Target& target,
      const std::shared_ptr<Scope>& scope) {
    auto compiler = std::make_shared<PirCompiler>(prog, target, scope);
    Instance().insert(compiler);
    return compiler;
  }

  // Keeps `compiler` alive for the lifetime of the registry.
  void insert(const std::shared_ptr<PirCompiler>& compiler) {
    compilers_.push_back(compiler);
  }

  // Drops all retained compilers.
  void clear() { compilers_.clear(); }

 private:
  std::vector<std::shared_ptr<PirCompiler>> compilers_;
};
}
// namespace framework
}
// namespace hlir
}
// namespace cinn
paddle/cinn/hlir/framework/visualize_helper.cc
View file @
01a10755
...
...
@@ -30,8 +30,8 @@
#include "paddle/cinn/utils/dot_lang.h"
#include "paddle/cinn/utils/string.h"
DECLARE_string
(
cinn_pass_visualize_dir
);
DECLARE_string
(
cinn_check_fusion_accuracy_pass
);
PD_
DECLARE_string
(
cinn_pass_visualize_dir
);
PD_
DECLARE_string
(
cinn_check_fusion_accuracy_pass
);
namespace
cinn
{
namespace
hlir
{
namespace
framework
{
...
...
paddle/cinn/hlir/op/broadcast.cc
View file @
01a10755
...
...
@@ -16,6 +16,7 @@
#include <iostream>
#include "paddle/cinn/adt/op_equation_context.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
...
...
@@ -124,6 +125,32 @@ std::vector<Type> InferDtypeForBroadcast(const std::vector<Type> &inputs_type,
return
res
;
}
// Emits ADT equations for a binary broadcast op: each input's axes are
// right-aligned against the output's axes and related through
// broadcast-aware iterators.
void GenerateEquationsForBroadcast(
    cinn::adt::config::OpEquationContext* ctx) {
  CHECK(ctx->GetInTensorsRanks().size() == 2)
      << "The inputs is " << ctx->GetInTensorsRanks().size()
      << "! Please check again.";
  CHECK(ctx->GetOutTensorsRanks().size() == 1)
      << "The output is " << ctx->GetOutTensorsRanks().size()
      << "! Please check again.";
  const std::uint64_t out_rank = ctx->GetOutTensorsRanks().at(0);

  // Binds input `in_idx`: its axis i corresponds to output axis
  // i + (out_rank - in_rank), i.e. trailing axes are aligned.
  auto bind_input = [&](std::size_t in_idx) {
    const std::uint64_t in_rank = ctx->GetInTensorsRanks().at(in_idx);
    const int offset = out_rank - in_rank;
    for (std::size_t i = 0; i < in_rank; ++i) {
      ctx->Equal(ctx->GetInIteratorTuple(in_idx)->at(i),
                 ctx->GetBroadcastedInputIterator(
                     ctx->GetOutIteratorTuple(0)->at(i + offset),
                     ctx->GetInDimTuple(in_idx)->at(i)));
    }
  };
  bind_input(0);
  bind_input(1);
}
std
::
vector
<
Type
>
InferDtypeForBroadcastCmp
(
const
std
::
vector
<
Type
>
&
inputs_type
,
const
framework
::
AttrMapType
&
attrs
)
{
CHECK
(
!
inputs_type
.
empty
())
...
...
@@ -242,6 +269,24 @@ std::vector<shape_t> InferShapeForBroadcastTo(
return
{
out_shape
};
}
// Emits ADT equations for broadcast_to: the single input's axes are
// right-aligned with the output's, and each aligned pair is related
// through a broadcast-aware iterator.
void GenerateEquationsForBroadcastTo(
    cinn::adt::config::OpEquationContext* ctx) {
  CHECK(ctx->GetInTensorsRanks().size() == 1)
      << "The inputs is " << ctx->GetInTensorsRanks().size()
      << "! Please check again.";
  CHECK(ctx->GetOutTensorsRanks().size() == 1)
      << "The output is " << ctx->GetOutTensorsRanks().size()
      << "! Please check again.";
  const std::size_t out_rank = ctx->GetOutTensorsRanks().at(0);
  // Leading output axes below start_axis have no input counterpart.
  const int start_axis = out_rank - ctx->GetInTensorsRanks().at(0);
  for (std::size_t axis = start_axis; axis < out_rank; ++axis) {
    ctx->Equal(ctx->GetInIteratorTuple(0)->at(axis - start_axis),
               ctx->GetBroadcastedInputIterator(
                   ctx->GetOutIteratorTuple(0)->at(axis),
                   ctx->GetInDimTuple(0)->at(axis - start_axis)));
  }
}
std
::
vector
<
std
::
vector
<
std
::
string
>>
InferLayoutForBroadcastTo
(
const
std
::
vector
<
std
::
vector
<
int
>>
&
input_shapes
,
const
std
::
vector
<
std
::
string
>
&
input_layouts
,
...
...
@@ -412,6 +457,8 @@ CINN_REGISTER_HELPER(broadcast_ops) {
MakeOpFunction(cinn::hlir::op::InferShapeForBroadcast)) \
.set_attr("inferdtype", \
MakeOpFunction(cinn::hlir::op::InferDtypeForBroadcast)) \
.set_attr("generate_equations", \
MakeOpFunction(cinn::hlir::op::GenerateEquationsForBroadcast)) \
.set_attr("inferlayout", \
MakeOpFunction(cinn::hlir::op::InferLayoutForBroadcast)) \
.set_attr<cinn::hlir::framework::OpPatternKind>( \
...
...
@@ -476,6 +523,8 @@ CINN_REGISTER_HELPER(broadcast_ops) {
MakeOpFunction
(
cinn
::
hlir
::
op
::
InferShapeForBroadcastTo
))
.
set_attr
(
"inferdtype"
,
MakeOpFunction
(
cinn
::
hlir
::
op
::
InferDtypeForBroadcast
))
.
set_attr
(
"generate_equations"
,
MakeOpFunction
(
cinn
::
hlir
::
op
::
GenerateEquationsForBroadcastTo
))
#ifndef CINN_WITH_CUDA
.
set_attr
(
"inferlayout"
,
MakeOpFunction
(
cinn
::
hlir
::
op
::
InferLayoutForBroadcastTo
))
...
...
paddle/cinn/hlir/op/contrib/argmax.cc
View file @
01a10755
...
...
@@ -161,6 +161,25 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgmax(
ir_sch
.
SetBuffer
(
blocks
[
0
],
"local"
);
ir_sch
.
SetBuffer
(
blocks
[
1
],
"local"
);
int
iter_var_size
=
blocks
[
0
]
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
iter_vars
.
size
();
int
real_axis
=
axis
;
if
(
real_axis
<
0
)
{
real_axis
+=
iter_var_size
;
}
blocks
[
0
]
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
iter_vars
[
real_axis
]
->
is_reduce_axis
=
true
;
blocks
[
1
]
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
iter_vars
[
real_axis
]
->
is_reduce_axis
=
true
;
int64_t
prod_size
=
std
::
accumulate
(
output_shapes
[
0
].
begin
(),
output_shapes
[
0
].
end
(),
1
,
...
...
paddle/cinn/hlir/op/contrib/argmin.cc
View file @
01a10755
...
...
@@ -158,6 +158,26 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgmin(
// variables, because the size will exceed the limit.
ir_sch
.
SetBuffer
(
blocks
[
0
],
"local"
);
ir_sch
.
SetBuffer
(
blocks
[
1
],
"local"
);
int
iter_var_size
=
blocks
[
0
]
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
iter_vars
.
size
();
int
real_axis
=
axis
;
if
(
real_axis
<
0
)
{
real_axis
+=
iter_var_size
;
}
blocks
[
0
]
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
iter_vars
[
real_axis
]
->
is_reduce_axis
=
true
;
blocks
[
1
]
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
iter_vars
[
real_axis
]
->
is_reduce_axis
=
true
;
int64_t
prod_size
=
std
::
accumulate
(
output_shapes
[
0
].
begin
(),
output_shapes
[
0
].
end
(),
1
,
...
...
paddle/cinn/hlir/op/contrib/bitcast_convert.cc
View file @
01a10755
...
...
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <memory>
#include <string>
#include <utility>
...
...
paddle/cinn/hlir/op/contrib/cholesky.cc
View file @
01a10755
...
...
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <memory>
#include <string>
#include <utility>
...
...
paddle/cinn/hlir/op/contrib/gather_nd.cc
View file @
01a10755
...
...
@@ -14,8 +14,6 @@
#include "paddle/cinn/hlir/op/contrib/gather_nd.h"
#include <gflags/gflags.h>
#include <memory>
#include <string>
#include <utility>
...
...
paddle/cinn/hlir/op/contrib/gaussian_random.cc
View file @
01a10755
...
...
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <memory>
#include <string>
#include <utility>
...
...
paddle/cinn/hlir/op/contrib/logical_right_shift.cc
View file @
01a10755
...
...
@@ -17,7 +17,6 @@
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/context.h"
...
...
@@ -37,6 +36,7 @@
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/utils/flags.h"
namespace
cinn
{
namespace
hlir
{
...
...
paddle/cinn/hlir/op/contrib/lookup_table.cc
View file @
01a10755
...
...
@@ -19,7 +19,6 @@
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/context.h"
...
...
@@ -38,6 +37,7 @@
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/utils/flags.h"
namespace
cinn
{
namespace
hlir
{
...
...
paddle/cinn/hlir/op/contrib/one_hot.cc
View file @
01a10755
...
...
@@ -14,8 +14,6 @@
#include "paddle/cinn/hlir/op/contrib/one_hot.h"
#include <gflags/gflags.h>
#include <memory>
#include <string>
#include <utility>
...
...
paddle/cinn/hlir/op/contrib/randint.cc
View file @
01a10755
...
...
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <memory>
#include <string>
#include <utility>
...
...
paddle/cinn/hlir/op/contrib/reciprocal.cc
View file @
01a10755
...
...
@@ -17,7 +17,6 @@
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/context.h"
...
...
@@ -37,6 +36,7 @@
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/utils/flags.h"
namespace
cinn
{
namespace
hlir
{
...
...
paddle/cinn/hlir/op/contrib/repeat.cc
View file @
01a10755
...
...
@@ -14,8 +14,6 @@
#include "paddle/cinn/hlir/op/contrib/repeat.h"
#include <gflags/gflags.h>
#include <memory>
#include <string>
#include <utility>
...
...
paddle/cinn/hlir/op/contrib/resize.cc
View file @
01a10755
...
...
@@ -14,8 +14,6 @@
#include "paddle/cinn/hlir/op/contrib/resize.h"
#include <gflags/gflags.h>
#include <memory>
#include <string>
#include <utility>
...
...
paddle/cinn/hlir/op/contrib/sort.cc
View file @
01a10755
...
...
@@ -14,8 +14,6 @@
#include "paddle/cinn/hlir/op/contrib/sort.h"
#include <gflags/gflags.h>
#include <memory>
#include <string>
#include <utility>
...
...
paddle/cinn/hlir/op/contrib/uniform_random.cc
View file @
01a10755
...
...
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <memory>
#include <string>
#include <utility>
...
...
Prev
1
…
14
15
16
17
18
19
20
21
22
…
29
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment