Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
549416f7
Commit
549416f7
authored
Jan 11, 2025
by
LeiWang1999
Browse files
Merge branch 'main' of
https://github.com/microsoft/TileLang
into main
parents
4d63633a
7fad4e88
Changes
90
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
11122 additions
and
8555 deletions
+11122
-8555
src/op/gemm.h
src/op/gemm.h
+9
-9
src/op/op.cc
src/op/op.cc
+18
-10
src/op/op.h
src/op/op.h
+27
-24
src/op/parallel.cc
src/op/parallel.cc
+66
-45
src/op/parallel.h
src/op/parallel.h
+14
-14
src/op/reduce.cc
src/op/reduce.cc
+84
-70
src/op/reduce.h
src/op/reduce.h
+9
-9
src/runtime/runtime.cc
src/runtime/runtime.cc
+68
-48
src/runtime/runtime.h
src/runtime/runtime.h
+7
-5
src/target/codegen_cuda.cc
src/target/codegen_cuda.cc
+472
-379
src/target/codegen_cuda.h
src/target/codegen_cuda.h
+50
-40
src/target/codegen_hip.cc
src/target/codegen_hip.cc
+377
-307
src/target/codegen_hip.h
src/target/codegen_hip.h
+42
-34
src/target/cuda.h
src/target/cuda.h
+9756
-7470
src/target/rt_mod_cuda.cc
src/target/rt_mod_cuda.cc
+21
-14
src/target/rt_mod_hip.cc
src/target/rt_mod_hip.cc
+42
-31
src/target/utils.cc
src/target/utils.cc
+23
-12
src/target/utils.h
src/target/utils.h
+4
-4
src/tl_templates/cuda/common.h
src/tl_templates/cuda/common.h
+21
-20
src/tl_templates/cuda/copy.h
src/tl_templates/cuda/copy.h
+12
-10
No files found.
src/op/gemm.h
View file @
549416f7
...
@@ -18,18 +18,18 @@ namespace tl {
...
@@ -18,18 +18,18 @@ namespace tl {
using
namespace
tir
;
using
namespace
tir
;
class
Gemm
:
public
Operator
{
class
Gemm
:
public
Operator
{
public:
public:
Gemm
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
Gemm
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
Stmt
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
final
;
Stmt
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
final
;
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
final
;
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
final
;
static
const
Op
&
Get
();
static
const
Op
&
Get
();
enum
class
GemmWarpPolicy
{
enum
class
GemmWarpPolicy
{
kSquare
=
0
,
kSquare
=
0
,
kFullRow
=
1
,
kFullRow
=
1
,
kFullCol
=
2
,
kFullCol
=
2
,
}
policy
;
}
policy
;
private:
private:
std
::
pair
<
int
,
int
>
ComputeWarpPartition
(
int
num_warps
,
Target
target
)
const
;
std
::
pair
<
int
,
int
>
ComputeWarpPartition
(
int
num_warps
,
Target
target
)
const
;
Array
<
PrimExpr
>
call_args
;
Array
<
PrimExpr
>
call_args
;
...
@@ -38,11 +38,11 @@ class Gemm : public Operator {
...
@@ -38,11 +38,11 @@ class Gemm : public Operator {
int
M
,
N
,
K
;
int
M
,
N
,
K
;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
// only will be enabled under cdna mfma instructions
int
kPack
=
1
;
int
kPack
=
1
;
bool
completed_
=
false
;
bool
completed_
=
false
;
};
};
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
#endif // TVM_TL_OP_GEMM_H_
#endif // TVM_TL_OP_GEMM_H_
\ No newline at end of file
\ No newline at end of file
src/op/op.cc
View file @
549416f7
...
@@ -20,13 +20,14 @@ using namespace tir;
...
@@ -20,13 +20,14 @@ using namespace tir;
TIR_REGISTER_TL_OP
(
RegionOp
,
region
)
TIR_REGISTER_TL_OP
(
RegionOp
,
region
)
.
set_num_inputs
(
-
1
)
.
set_num_inputs
(
-
1
)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kPure
));
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kPure
));
std
::
unique_ptr
<
Operator
>
ParseOperator
(
Call
call
,
BufferMap
vmap
)
{
std
::
unique_ptr
<
Operator
>
ParseOperator
(
Call
call
,
BufferMap
vmap
)
{
auto
op_map
=
Op
::
GetAttrMap
<
OpBuilderFunc
>
(
"TLOpBuilder"
);
auto
op_map
=
Op
::
GetAttrMap
<
OpBuilderFunc
>
(
"TLOpBuilder"
);
Op
op
=
call
->
op
.
as
<
Op
>
().
value
();
Op
op
=
call
->
op
.
as
<
Op
>
().
value
();
if
(
op_map
.
count
(
op
))
{
if
(
op_map
.
count
(
op
))
{
Operator
*
ptr
=
static_cast
<
Operator
*>
(
op_map
[
op
](
call
->
args
,
vmap
));
Operator
*
ptr
=
static_cast
<
Operator
*>
(
op_map
[
op
](
call
->
args
,
vmap
));
ICHECK
(
ptr
!=
nullptr
);
ICHECK
(
ptr
!=
nullptr
);
return
std
::
unique_ptr
<
Operator
>
(
ptr
);
return
std
::
unique_ptr
<
Operator
>
(
ptr
);
}
}
...
@@ -41,7 +42,7 @@ std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap) {
...
@@ -41,7 +42,7 @@ std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap) {
return
nullptr
;
return
nullptr
;
}
}
Var
GetVarFromAccessPtr
(
const
PrimExpr
&
expr
)
{
Var
GetVarFromAccessPtr
(
const
PrimExpr
&
expr
)
{
auto
call
=
expr
.
as
<
CallNode
>
();
auto
call
=
expr
.
as
<
CallNode
>
();
ICHECK
(
call
);
ICHECK
(
call
);
ICHECK
(
call
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()));
ICHECK
(
call
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()));
...
@@ -67,20 +68,27 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
...
@@ -67,20 +68,27 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
bool
RegionOp
::
IsFullRegion
()
const
{
bool
RegionOp
::
IsFullRegion
()
const
{
for
(
size_t
i
=
0
;
i
<
ranges_
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ranges_
.
size
();
i
++
)
{
if
(
!
is_zero
(
ranges_
[
i
]
->
min
))
return
false
;
if
(
!
is_zero
(
ranges_
[
i
]
->
min
))
if
(
!
StructuralEqual
()(
ranges_
[
i
]
->
extent
,
buffer_
->
shape
[
i
]))
return
false
;
return
false
;
if
(
!
StructuralEqual
()(
ranges_
[
i
]
->
extent
,
buffer_
->
shape
[
i
]))
return
false
;
}
}
return
true
;
return
true
;
}
}
Stmt
Operator
::
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
{
Stmt
Operator
::
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
{
ICHECK
(
0
)
<<
"Not Implemented Lower method."
;
ICHECK
(
0
)
<<
"Not Implemented Lower method."
;
return
Evaluate
(
0
);
return
Evaluate
(
0
);
}
}
Stmt
Operator
::
Canonialize
(
const
CanonializeArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
{
return
{};
}
Stmt
Operator
::
Canonialize
(
const
CanonializeArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
{
return
{};
}
LayoutMap
Operator
::
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
{
return
{};
}
LayoutMap
Operator
::
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
{
return
{};
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/op/op.h
View file @
549416f7
...
@@ -25,17 +25,19 @@ using namespace tir;
...
@@ -25,17 +25,19 @@ using namespace tir;
using
AddWorkspaceCallback
=
std
::
function
<
PrimExpr
(
int
,
DataType
)
>
;
using
AddWorkspaceCallback
=
std
::
function
<
PrimExpr
(
int
,
DataType
)
>
;
using
LayoutMap
=
Map
<
Buffer
,
Layout
>
;
using
LayoutMap
=
Map
<
Buffer
,
Layout
>
;
using
BufferMap
=
Map
<
Var
,
Buffer
>
;
using
BufferMap
=
Map
<
Var
,
Buffer
>
;
using
OpBuilderFunc
=
TypedPackedFunc
<
void
*
(
Array
<
PrimExpr
>
,
BufferMap
)
>
;
using
OpBuilderFunc
=
TypedPackedFunc
<
void
*
(
Array
<
PrimExpr
>
,
BufferMap
)
>
;
#define TIR_REGISTER_TL_OP(Entry, OpName) \
#define TIR_REGISTER_TL_OP(Entry, OpName) \
const Op& Entry::Get() { \
const Op &Entry::Get() { \
static const Op& op = Op::Get("tl." #OpName); \
static const Op &op = Op::Get("tl." #OpName); \
return op; \
return op; \
} \
} \
TVM_REGISTER_OP("tl." #OpName) \
TVM_REGISTER_OP("tl." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
.set_attr<OpBuilderFunc>( \
.set_attr<OpBuilderFunc>("TLOpBuilder", \
"TLOpBuilder", [](Array<PrimExpr> a, BufferMap b) { return (void*)(new Entry(a, b)); })
[](Array<PrimExpr> a, BufferMap b) { \
return (void *)(new Entry(a, b)); \
})
enum
class
InferLevel
{
enum
class
InferLevel
{
kFree
=
0
,
kFree
=
0
,
...
@@ -64,35 +66,36 @@ struct CanonializeArgs {
...
@@ -64,35 +66,36 @@ struct CanonializeArgs {
};
};
class
Operator
{
class
Operator
{
public:
public:
virtual
Stmt
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
;
virtual
Stmt
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
;
virtual
Stmt
Canonialize
(
const
CanonializeArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
;
virtual
Stmt
Canonialize
(
const
CanonializeArgs
&
T
,
virtual
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
);
arith
::
Analyzer
*
analyzer
)
const
;
virtual
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
);
virtual
~
Operator
()
=
default
;
virtual
~
Operator
()
=
default
;
};
};
class
RegionOp
:
public
Operator
{
class
RegionOp
:
public
Operator
{
public:
public:
RegionOp
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
RegionOp
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
static
const
Op
&
Get
();
static
const
Op
&
Get
();
const
Buffer
&
GetBuffer
()
const
{
return
buffer_
;
}
const
Buffer
&
GetBuffer
()
const
{
return
buffer_
;
}
const
Array
<
Range
>
&
GetRanges
()
const
{
return
ranges_
;
}
const
Array
<
Range
>
&
GetRanges
()
const
{
return
ranges_
;
}
int
GetAccessMask
()
const
{
return
access_mask_
;
}
int
GetAccessMask
()
const
{
return
access_mask_
;
}
bool
IsFullRegion
()
const
;
bool
IsFullRegion
()
const
;
private:
private:
Buffer
buffer_
;
Buffer
buffer_
;
Array
<
Range
>
ranges_
;
Array
<
Range
>
ranges_
;
int
access_mask_
;
int
access_mask_
;
};
};
Var
GetVarFromAccessPtr
(
const
PrimExpr
&
expr
);
Var
GetVarFromAccessPtr
(
const
PrimExpr
&
expr
);
std
::
unique_ptr
<
Operator
>
ParseOperator
(
Call
call
,
BufferMap
vmap
);
std
::
unique_ptr
<
Operator
>
ParseOperator
(
Call
call
,
BufferMap
vmap
);
std
::
unique_ptr
<
Operator
>
ParseOperator
(
Stmt
stmt
,
BufferMap
vmap
);
std
::
unique_ptr
<
Operator
>
ParseOperator
(
Stmt
stmt
,
BufferMap
vmap
);
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
#endif
// TVM_TL_OP_OP_H_
#endif // TVM_TL_OP_OP_H_
src/op/parallel.cc
View file @
549416f7
...
@@ -39,21 +39,22 @@ using namespace tir;
...
@@ -39,21 +39,22 @@ using namespace tir;
namespace
attr
{
namespace
attr
{
/*! \brief Mark that how the loop is vectorized. */
/*! \brief Mark that how the loop is vectorized. */
constexpr
const
char
*
coalesced_width
=
"coalesced_width"
;
constexpr
const
char
*
coalesced_width
=
"coalesced_width"
;
}
}
// namespace attr
class
IfBufferRemapLoopGenerator
:
public
StmtExprMutator
{
class
IfBufferRemapLoopGenerator
:
public
StmtExprMutator
{
public:
public:
static
For
run
(
Stmt
stmt
,
Map
<
Buffer
,
Buffer
>
buffer_remap
,
static
For
run
(
Stmt
stmt
,
Map
<
Buffer
,
Buffer
>
buffer_remap
,
Map
<
Buffer
,
Layout
>
layout_map
)
{
Map
<
Buffer
,
Layout
>
layout_map
)
{
IfBufferRemapLoopGenerator
generator
(
buffer_remap
,
layout_map
);
IfBufferRemapLoopGenerator
generator
(
buffer_remap
,
layout_map
);
return
Downcast
<
For
>
(
generator
(
std
::
move
(
stmt
)));
return
Downcast
<
For
>
(
generator
(
std
::
move
(
stmt
)));
}
}
private:
private:
IfBufferRemapLoopGenerator
(
Map
<
Buffer
,
Buffer
>
buffer_remap
,
Map
<
Buffer
,
Layout
>
layout_map
)
IfBufferRemapLoopGenerator
(
Map
<
Buffer
,
Buffer
>
buffer_remap
,
Map
<
Buffer
,
Layout
>
layout_map
)
:
buffer_remap_
(
buffer_remap
),
layout_map_
(
layout_map
)
{}
:
buffer_remap_
(
buffer_remap
),
layout_map_
(
layout_map
)
{}
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
auto
load
=
Downcast
<
BufferLoad
>
(
StmtExprMutator
::
VisitExpr_
(
op
));
auto
load
=
Downcast
<
BufferLoad
>
(
StmtExprMutator
::
VisitExpr_
(
op
));
if
(
buffer_remap_
.
count
(
load
->
buffer
))
{
if
(
buffer_remap_
.
count
(
load
->
buffer
))
{
...
@@ -65,7 +66,7 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
...
@@ -65,7 +66,7 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
return
load
;
return
load
;
}
}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
auto
store
=
Downcast
<
BufferStore
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
auto
store
=
Downcast
<
BufferStore
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
if
(
buffer_remap_
.
count
(
store
->
buffer
))
{
if
(
buffer_remap_
.
count
(
store
->
buffer
))
{
auto
new_indices
=
layout_map_
[
store
->
buffer
]
->
Forward
(
store
->
indices
);
auto
new_indices
=
layout_map_
[
store
->
buffer
]
->
Forward
(
store
->
indices
);
...
@@ -79,18 +80,20 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
...
@@ -79,18 +80,20 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
Map
<
Buffer
,
Layout
>
layout_map_
;
Map
<
Buffer
,
Layout
>
layout_map_
;
};
};
void
ParallelLoopNestVisitor
::
VisitStmt_
(
const
ForNode
*
op
)
{
void
ParallelLoopNestVisitor
::
VisitStmt_
(
const
ForNode
*
op
)
{
ICHECK
(
op
->
kind
==
ForKind
::
kParallel
);
ICHECK
(
op
->
kind
==
ForKind
::
kParallel
);
p
->
loop_vars_
.
push_back
(
IterVar
(
Range
(
op
->
min
,
op
->
extent
),
op
->
loop_var
,
IterVarType
::
kDataPar
));
p
->
loop_vars_
.
push_back
(
IterVar
(
Range
(
op
->
min
,
op
->
extent
),
op
->
loop_var
,
IterVarType
::
kDataPar
));
p
->
analyzer_
.
Bind
(
op
->
loop_var
,
Range
::
FromMinExtent
(
op
->
min
,
op
->
extent
));
p
->
analyzer_
.
Bind
(
op
->
loop_var
,
Range
::
FromMinExtent
(
op
->
min
,
op
->
extent
));
StmtExprVisitor
::
VisitStmt_
(
op
);
StmtExprVisitor
::
VisitStmt_
(
op
);
}
}
void
ParallelLoopNestVisitor
::
VisitStmt_
(
const
BufferStoreNode
*
op
)
{
void
ParallelLoopNestVisitor
::
VisitStmt_
(
const
BufferStoreNode
*
op
)
{
if
(
op
->
buffer
.
scope
()
==
"local.fragment"
)
{
if
(
op
->
buffer
.
scope
()
==
"local.fragment"
)
{
if
(
p
->
indice_map_
.
find
(
op
->
buffer
)
!=
p
->
indice_map_
.
end
())
{
if
(
p
->
indice_map_
.
find
(
op
->
buffer
)
!=
p
->
indice_map_
.
end
())
{
ICHECK
(
StructuralEqual
()(
p
->
indice_map_
.
at
(
op
->
buffer
),
op
->
indices
))
ICHECK
(
StructuralEqual
()(
p
->
indice_map_
.
at
(
op
->
buffer
),
op
->
indices
))
<<
op
->
buffer
<<
": "
<<
op
->
indices
<<
" and "
<<
p
->
indice_map_
.
at
(
op
->
buffer
);
<<
op
->
buffer
<<
": "
<<
op
->
indices
<<
" and "
<<
p
->
indice_map_
.
at
(
op
->
buffer
);
}
else
{
}
else
{
p
->
indice_map_
.
Set
(
op
->
buffer
,
op
->
indices
);
p
->
indice_map_
.
Set
(
op
->
buffer
,
op
->
indices
);
}
}
...
@@ -99,11 +102,12 @@ void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode* op) {
...
@@ -99,11 +102,12 @@ void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode* op) {
StmtExprVisitor
::
VisitStmt_
(
op
);
StmtExprVisitor
::
VisitStmt_
(
op
);
}
}
void
ParallelLoopNestVisitor
::
VisitExpr_
(
const
BufferLoadNode
*
op
)
{
void
ParallelLoopNestVisitor
::
VisitExpr_
(
const
BufferLoadNode
*
op
)
{
if
(
op
->
buffer
.
scope
()
==
"local.fragment"
)
{
if
(
op
->
buffer
.
scope
()
==
"local.fragment"
)
{
if
(
p
->
indice_map_
.
find
(
op
->
buffer
)
!=
p
->
indice_map_
.
end
())
{
if
(
p
->
indice_map_
.
find
(
op
->
buffer
)
!=
p
->
indice_map_
.
end
())
{
ICHECK
(
StructuralEqual
()(
p
->
indice_map_
.
at
(
op
->
buffer
),
op
->
indices
))
ICHECK
(
StructuralEqual
()(
p
->
indice_map_
.
at
(
op
->
buffer
),
op
->
indices
))
<<
op
->
buffer
<<
": "
<<
op
->
indices
<<
" and "
<<
p
->
indice_map_
.
at
(
op
->
buffer
);
<<
op
->
buffer
<<
": "
<<
op
->
indices
<<
" and "
<<
p
->
indice_map_
.
at
(
op
->
buffer
);
}
else
{
}
else
{
p
->
indice_map_
.
Set
(
op
->
buffer
,
op
->
indices
);
p
->
indice_map_
.
Set
(
op
->
buffer
,
op
->
indices
);
}
}
...
@@ -113,18 +117,20 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode* op) {
...
@@ -113,18 +117,20 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode* op) {
ParallelOp
::
ParallelOp
(
For
root
)
:
root_
(
root
),
V
(
this
)
{
V
.
VisitStmt
(
root
);
}
ParallelOp
::
ParallelOp
(
For
root
)
:
root_
(
root
),
V
(
this
)
{
V
.
VisitStmt
(
root
);
}
bool
ParallelOp
::
IsCommonAccessIndice
(
const
Buffer
&
buffer
)
const
{
bool
ParallelOp
::
IsCommonAccessIndice
(
const
Buffer
&
buffer
)
const
{
auto
common_indice
=
loop_vars_
.
Map
([](
const
auto
&
iv
)
{
return
iv
->
var
;
});
auto
common_indice
=
loop_vars_
.
Map
([](
const
auto
&
iv
)
{
return
iv
->
var
;
});
return
StructuralEqual
()(
indice_map_
[
buffer
],
common_indice
);
return
StructuralEqual
()(
indice_map_
[
buffer
],
common_indice
);
}
}
LayoutMap
ParallelOp
::
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
{
LayoutMap
ParallelOp
::
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
{
if
(
loop_layout_
.
defined
())
return
{};
if
(
loop_layout_
.
defined
())
if
(
level
==
InferLevel
::
kStrict
)
return
{};
return
{};
if
(
level
==
InferLevel
::
kStrict
)
return
{};
// Step 1: try to infer loop's partition from a source fragment
// Step 1: try to infer loop's partition from a source fragment
Buffer
source_buffer
,
read_source_buffer
;
Buffer
source_buffer
,
read_source_buffer
;
for
(
const
auto
&
[
buffer
,
_
]
:
indice_map_
)
{
for
(
const
auto
&
[
buffer
,
_
]
:
indice_map_
)
{
if
(
T
.
layout_map
.
count
(
buffer
))
{
if
(
T
.
layout_map
.
count
(
buffer
))
{
auto
frag
=
T
.
layout_map
[
buffer
].
as
<
Fragment
>
().
value
();
auto
frag
=
T
.
layout_map
[
buffer
].
as
<
Fragment
>
().
value
();
if
(
buffer_is_write_
.
count
(
buffer
))
if
(
buffer_is_write_
.
count
(
buffer
))
...
@@ -133,14 +139,16 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
...
@@ -133,14 +139,16 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
read_source_buffer
=
buffer
;
read_source_buffer
=
buffer
;
}
}
}
}
auto
compute_loop_layout_from_buffer
=
[
&
](
const
Buffer
&
buffer
)
{
auto
compute_loop_layout_from_buffer
=
[
&
](
const
Buffer
&
buffer
)
{
Fragment
src_layout
=
T
.
layout_map
[
buffer
].
as
<
Fragment
>
().
value
();
Fragment
src_layout
=
T
.
layout_map
[
buffer
].
as
<
Fragment
>
().
value
();
if
(
IsCommonAccessIndice
(
buffer
))
{
if
(
IsCommonAccessIndice
(
buffer
))
{
return
src_layout
;
return
src_layout
;
}
else
{
}
else
{
Var
rep
;
Var
rep
;
auto
rep_iter
=
IterVar
({
0
,
src_layout
->
ReplicateExtent
()},
rep
,
IterVarType
::
kDataPar
);
auto
rep_iter
=
IterVar
({
0
,
src_layout
->
ReplicateExtent
()},
rep
,
PrimExpr
loop_var_to_thread
=
src_layout
->
ForwardThread
(
indice_map_
[
buffer
],
rep
);
IterVarType
::
kDataPar
);
PrimExpr
loop_var_to_thread
=
src_layout
->
ForwardThread
(
indice_map_
[
buffer
],
rep
);
return
Fragment
(
loop_vars_
,
{},
loop_var_to_thread
,
rep_iter
);
return
Fragment
(
loop_vars_
,
{},
loop_var_to_thread
,
rep_iter
);
}
}
};
};
...
@@ -150,12 +158,14 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
...
@@ -150,12 +158,14 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
if
(
read_source_buffer
.
defined
())
{
if
(
read_source_buffer
.
defined
())
{
loop_layout_
=
compute_loop_layout_from_buffer
(
read_source_buffer
);
loop_layout_
=
compute_loop_layout_from_buffer
(
read_source_buffer
);
// Loop don't need to be replicated.
// Loop don't need to be replicated.
if
(
!
is_one
(
loop_layout_
->
ReplicateExtent
()))
loop_layout_
=
loop_layout_
->
DeReplicate
();
if
(
!
is_one
(
loop_layout_
->
ReplicateExtent
()))
loop_layout_
=
loop_layout_
->
DeReplicate
();
// if still has replication, add a condition
// if still has replication, add a condition
if
(
!
is_one
(
loop_layout_
->
ReplicateExtent
()))
{
if
(
!
is_one
(
loop_layout_
->
ReplicateExtent
()))
{
auto
inv
=
loop_layout_
->
Inverse
();
auto
inv
=
loop_layout_
->
Inverse
();
Array
<
PrimExpr
>
fwd
;
Array
<
PrimExpr
>
fwd
;
for
(
size_t
i
=
0
;
i
<
loop_layout_
->
OutputDim
();
i
++
)
fwd
.
push_back
(
0
);
for
(
size_t
i
=
0
;
i
<
loop_layout_
->
OutputDim
();
i
++
)
fwd
.
push_back
(
0
);
fwd
.
push_back
(
InputPlaceholder
(
0
));
fwd
.
push_back
(
InputPlaceholder
(
0
));
auto
rep
=
inv
->
Forward
(
fwd
).
back
();
auto
rep
=
inv
->
Forward
(
fwd
).
back
();
AddPredicate
(
EQ
(
rep
,
0
));
AddPredicate
(
EQ
(
rep
,
0
));
...
@@ -163,17 +173,19 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
...
@@ -163,17 +173,19 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
}
else
{
}
else
{
// Vectorize Size must be aware of the buffer_remap
// Vectorize Size must be aware of the buffer_remap
// As the pass will do post processing to the layout
// As the pass will do post processing to the layout
auto
maybe_remapped_root_
=
IfBufferRemapLoopGenerator
::
run
(
root_
,
T
.
buffer_remap
,
T
.
layout_map
);
auto
maybe_remapped_root_
=
IfBufferRemapLoopGenerator
::
run
(
root_
,
T
.
buffer_remap
,
T
.
layout_map
);
int
vector_size
=
GetVectorizeSize
(
maybe_remapped_root_
);
int
vector_size
=
GetVectorizeSize
(
maybe_remapped_root_
);
// Check if coalesced_width is defined
// Check if coalesced_width is defined
if
(
auto
coalesced_width
=
root_
->
annotations
.
Get
(
tl
::
attr
::
coalesced_width
))
{
if
(
auto
coalesced_width
=
if
(
const
auto
*
imm
=
coalesced_width
.
as
<
IntImmNode
>
())
{
root_
->
annotations
.
Get
(
tl
::
attr
::
coalesced_width
))
{
if
(
const
auto
*
imm
=
coalesced_width
.
as
<
IntImmNode
>
())
{
int
expected
=
imm
->
value
;
int
expected
=
imm
->
value
;
// Verify that vector_size is divisible by expected
// Verify that vector_size is divisible by expected
if
(
vector_size
%
expected
!=
0
)
{
if
(
vector_size
%
expected
!=
0
)
{
LOG
(
FATAL
)
<<
"Vector size "
<<
vector_size
<<
" is not divisible by coalesced width "
LOG
(
FATAL
)
<<
"Vector size "
<<
vector_size
<<
expected
;
<<
" is not divisible by coalesced width "
<<
expected
;
}
}
vector_size
=
expected
;
vector_size
=
expected
;
}
else
{
}
else
{
...
@@ -184,31 +196,37 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
...
@@ -184,31 +196,37 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
loop_layout_
=
PlanLoopPartition
(
root_
,
T
.
block_size
,
vector_size
);
loop_layout_
=
PlanLoopPartition
(
root_
,
T
.
block_size
,
vector_size
);
}
}
PrimExpr
loop_thread_extent
=
loop_layout_
->
ThreadExtent
();
PrimExpr
loop_thread_extent
=
loop_layout_
->
ThreadExtent
();
if
(
!
analyzer_
.
CanProveEqual
(
loop_thread_extent
,
static_cast
<
int
>
(
T
.
block_size
)))
if
(
!
analyzer_
.
CanProveEqual
(
loop_thread_extent
,
static_cast
<
int
>
(
T
.
block_size
)))
AddPredicate
(
LT
(
InputPlaceholder
(
0
),
loop_thread_extent
));
AddPredicate
(
LT
(
InputPlaceholder
(
0
),
loop_thread_extent
));
}
else
{
}
else
{
return
{};
return
{};
}
}
// Step 2: Check that the loop's partition can correctly align with all source fragment
// Step 2: Check that the loop's partition can correctly align with all source
for
(
const
auto
&
[
buffer
,
_
]
:
indice_map_
)
{
// fragment
for
(
const
auto
&
[
buffer
,
_
]
:
indice_map_
)
{
if
(
T
.
layout_map
.
count
(
buffer
))
{
if
(
T
.
layout_map
.
count
(
buffer
))
{
auto
fragment
=
T
.
layout_map
[
buffer
].
as
<
Fragment
>
().
value
();
auto
fragment
=
T
.
layout_map
[
buffer
].
as
<
Fragment
>
().
value
();
// TODO: Add thread checks for replicated cases
// TODO: Add thread checks for replicated cases
// need to wildcard match the rhs with lhs
// need to wildcard match the rhs with lhs
if
(
!
is_one
(
loop_layout_
->
ReplicateExtent
())
||
!
is_one
(
fragment
->
ReplicateExtent
()))
if
(
!
is_one
(
loop_layout_
->
ReplicateExtent
())
||
!
is_one
(
fragment
->
ReplicateExtent
()))
continue
;
continue
;
auto
vars
=
loop_vars_
.
Map
([](
const
IterVar
&
iv
)
{
return
PrimExpr
(
iv
->
var
);
});
auto
vars
=
loop_vars_
.
Map
([](
const
IterVar
&
iv
)
{
return
PrimExpr
(
iv
->
var
);
});
auto
lhs
=
loop_layout_
->
ForwardThread
(
vars
,
NullOpt
);
auto
lhs
=
loop_layout_
->
ForwardThread
(
vars
,
NullOpt
);
auto
rhs
=
fragment
->
ForwardThread
(
indice_map_
[
buffer
],
NullOpt
);
auto
rhs
=
fragment
->
ForwardThread
(
indice_map_
[
buffer
],
NullOpt
);
auto
diff
=
analyzer_
.
Simplify
(
lhs
-
rhs
);
auto
diff
=
analyzer_
.
Simplify
(
lhs
-
rhs
);
ICHECK
(
is_zero
(
diff
))
<<
"Layout infer conflict for "
<<
buffer
<<
" "
<<
source_buffer
ICHECK
(
is_zero
(
diff
))
<<
"
\n
LHS = "
<<
lhs
<<
"
\n
RHS = "
<<
rhs
;
<<
"Layout infer conflict for "
<<
buffer
<<
" "
<<
source_buffer
<<
"
\n
LHS = "
<<
lhs
<<
"
\n
RHS = "
<<
rhs
;
}
}
}
}
// Step 3: Infer other fragment's layout from the loop's partition
// Step 3: Infer other fragment's layout from the loop's partition
LayoutMap
results
;
LayoutMap
results
;
for
(
const
auto
&
[
buffer
,
_
]
:
indice_map_
)
{
for
(
const
auto
&
[
buffer
,
_
]
:
indice_map_
)
{
if
(
!
T
.
layout_map
.
count
(
buffer
))
results
.
Set
(
buffer
,
CompleteBufferFragment
(
buffer
));
if
(
!
T
.
layout_map
.
count
(
buffer
))
results
.
Set
(
buffer
,
CompleteBufferFragment
(
buffer
));
}
}
return
results
;
return
results
;
}
}
...
@@ -221,18 +239,20 @@ Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
...
@@ -221,18 +239,20 @@ Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
}
}
}
}
Fragment
ParallelOp
::
CompleteBufferFragment
(
const
Buffer
&
buffer
)
{
Fragment
ParallelOp
::
CompleteBufferFragment
(
const
Buffer
&
buffer
)
{
ICHECK
(
loop_layout_
.
defined
());
ICHECK
(
loop_layout_
.
defined
());
if
(
IsCommonAccessIndice
(
buffer
))
return
loop_layout_
;
if
(
IsCommonAccessIndice
(
buffer
))
return
loop_layout_
;
PrimExpr
rep_b
=
PrimExpr
rep_b
=
MakeFlattenedExpression
(
MakeFlattenedExpression
(
DivideUnusedIterators
(
indice_map_
[
buffer
],
loop_vars_
,
&
analyzer_
));
DivideUnusedIterators
(
indice_map_
[
buffer
],
loop_vars_
,
&
analyzer_
));
auto
bijective_indice
=
indice_map_
[
buffer
];
auto
bijective_indice
=
indice_map_
[
buffer
];
bijective_indice
.
push_back
(
rep_b
);
bijective_indice
.
push_back
(
rep_b
);
Layout
ind_inv
=
Layout
(
loop_vars_
,
bijective_indice
)
->
Inverse
();
Layout
ind_inv
=
Layout
(
loop_vars_
,
bijective_indice
)
->
Inverse
();
PrimExpr
indice_rep_extent
=
ind_inv
->
InputShape
().
back
();
// this is the size of rep_b
PrimExpr
indice_rep_extent
=
ind_inv
->
InputShape
().
back
();
// this is the size of rep_b
PrimExpr
loop_rep_extent
=
loop_layout_
->
ReplicateExtent
();
PrimExpr
loop_rep_extent
=
loop_layout_
->
ReplicateExtent
();
PrimExpr
dest_buffer_rep_extent
=
indice_rep_extent
*
loop_rep_extent
;
PrimExpr
dest_buffer_rep_extent
=
indice_rep_extent
*
loop_rep_extent
;
...
@@ -242,11 +262,12 @@ Fragment ParallelOp::CompleteBufferFragment(const Buffer& buffer) {
...
@@ -242,11 +262,12 @@ Fragment ParallelOp::CompleteBufferFragment(const Buffer& buffer) {
}
}
fwd
.
push_back
(
FloorMod
(
ReplicationPlaceholder
(),
indice_rep_extent
));
fwd
.
push_back
(
FloorMod
(
ReplicationPlaceholder
(),
indice_rep_extent
));
PrimExpr
thd_b
=
loop_layout_
->
ForwardThread
(
PrimExpr
thd_b
=
loop_layout_
->
ForwardThread
(
ind_inv
->
Forward
(
fwd
),
FloorDiv
(
ReplicationPlaceholder
(),
indice_rep_extent
));
ind_inv
->
Forward
(
fwd
),
FloorDiv
(
ReplicationPlaceholder
(),
indice_rep_extent
));
return
Fragment
(
buffer
->
shape
,
{},
thd_b
,
dest_buffer_rep_extent
,
NullOpt
)
return
Fragment
(
buffer
->
shape
,
{},
thd_b
,
dest_buffer_rep_extent
,
NullOpt
)
->
CondenseReplicateVar
();
->
CondenseReplicateVar
();
}
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/op/parallel.h
View file @
549416f7
...
@@ -23,30 +23,30 @@ using namespace tir;
...
@@ -23,30 +23,30 @@ using namespace tir;
class
ParallelOp
;
class
ParallelOp
;
class
ParallelLoopNestVisitor
:
public
StmtExprVisitor
{
class
ParallelLoopNestVisitor
:
public
StmtExprVisitor
{
private:
private:
ParallelLoopNestVisitor
(
ParallelOp
*
op
)
:
p
(
op
){};
ParallelLoopNestVisitor
(
ParallelOp
*
op
)
:
p
(
op
){};
void
VisitStmt_
(
const
ForNode
*
op
)
final
;
void
VisitStmt_
(
const
ForNode
*
op
)
final
;
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
;
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
;
void
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
;
void
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
;
ParallelOp
*
p
;
ParallelOp
*
p
;
friend
class
ParallelOp
;
friend
class
ParallelOp
;
};
};
class
ParallelOp
:
public
Operator
{
class
ParallelOp
:
public
Operator
{
public:
public:
ParallelOp
(
For
root
);
ParallelOp
(
For
root
);
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
final
;
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
final
;
Fragment
GetLoopLayout
()
const
{
return
loop_layout_
;
}
Fragment
GetLoopLayout
()
const
{
return
loop_layout_
;
}
For
GetRoot
()
const
{
return
root_
;
}
For
GetRoot
()
const
{
return
root_
;
}
Map
<
Buffer
,
Array
<
PrimExpr
>>
GetIndiceMap
()
const
{
return
indice_map_
;
}
Map
<
Buffer
,
Array
<
PrimExpr
>>
GetIndiceMap
()
const
{
return
indice_map_
;
}
Optional
<
PrimExpr
>
GetPredicate
(
Var
thread_var
)
const
;
Optional
<
PrimExpr
>
GetPredicate
(
Var
thread_var
)
const
;
private:
private:
Fragment
CompleteBufferFragment
(
const
Buffer
&
buffer
);
Fragment
CompleteBufferFragment
(
const
Buffer
&
buffer
);
bool
IsCommonAccessIndice
(
const
Buffer
&
buffer
)
const
;
bool
IsCommonAccessIndice
(
const
Buffer
&
buffer
)
const
;
void
AddPredicate
(
PrimExpr
expr
)
{
void
AddPredicate
(
PrimExpr
expr
)
{
predicate_
=
predicate_
.
defined
()
?
And
(
expr
,
predicate_
.
value
())
:
expr
;
predicate_
=
predicate_
.
defined
()
?
And
(
expr
,
predicate_
.
value
())
:
expr
;
}
}
...
@@ -66,7 +66,7 @@ class ParallelOp : public Operator {
...
@@ -66,7 +66,7 @@ class ParallelOp : public Operator {
friend
class
ParallelLoopNestVisitor
;
friend
class
ParallelLoopNestVisitor
;
};
};
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
#endif
// TVM_TL_OP_PARALLEL_H_
#endif // TVM_TL_OP_PARALLEL_H_
src/op/reduce.cc
View file @
549416f7
...
@@ -41,57 +41,58 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
...
@@ -41,57 +41,58 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
PrimExpr
ReduceOp
::
MakeInitValue
()
const
{
PrimExpr
ReduceOp
::
MakeInitValue
()
const
{
switch
(
type
)
{
switch
(
type
)
{
case
ReduceType
::
kSum
:
case
ReduceType
::
kSum
:
return
make_zero
(
dst
->
dtype
);
return
make_zero
(
dst
->
dtype
);
case
ReduceType
::
kAbsSum
:
case
ReduceType
::
kAbsSum
:
return
make_zero
(
dst
->
dtype
);
return
make_zero
(
dst
->
dtype
);
case
ReduceType
::
kMax
:
case
ReduceType
::
kMax
:
return
make_const
(
dst
->
dtype
,
-
INFINITY
);
return
make_const
(
dst
->
dtype
,
-
INFINITY
);
case
ReduceType
::
kMin
:
case
ReduceType
::
kMin
:
return
make_const
(
dst
->
dtype
,
INFINITY
);
return
make_const
(
dst
->
dtype
,
INFINITY
);
default:
default:
ICHECK
(
0
);
ICHECK
(
0
);
}
}
}
}
PrimExpr
ReduceOp
::
MakeReduce
(
const
PrimExpr
&
a
,
const
PrimExpr
&
b
)
const
{
PrimExpr
ReduceOp
::
MakeReduce
(
const
PrimExpr
&
a
,
const
PrimExpr
&
b
)
const
{
PrimExpr
lhs
=
a
,
rhs
=
b
;
PrimExpr
lhs
=
a
,
rhs
=
b
;
if
(
lhs
->
dtype
!=
rhs
->
dtype
)
{
if
(
lhs
->
dtype
!=
rhs
->
dtype
)
{
rhs
=
Cast
(
lhs
->
dtype
,
rhs
);
rhs
=
Cast
(
lhs
->
dtype
,
rhs
);
}
}
switch
(
type
)
{
switch
(
type
)
{
case
ReduceType
::
kSum
:
case
ReduceType
::
kSum
:
return
lhs
+
rhs
;
return
lhs
+
rhs
;
case
ReduceType
::
kAbsSum
:
case
ReduceType
::
kAbsSum
:
return
lhs
+
Max
(
rhs
,
-
rhs
);
return
lhs
+
Max
(
rhs
,
-
rhs
);
case
ReduceType
::
kMax
:
case
ReduceType
::
kMax
:
return
Max
(
lhs
,
rhs
);
return
Max
(
lhs
,
rhs
);
case
ReduceType
::
kMin
:
case
ReduceType
::
kMin
:
return
Min
(
lhs
,
rhs
);
return
Min
(
lhs
,
rhs
);
default:
default:
ICHECK
(
0
);
ICHECK
(
0
);
return
PrimExpr
(
0
);
return
PrimExpr
(
0
);
}
}
}
}
std
::
string
ReduceOp
::
MakeCodegenReducer
()
const
{
std
::
string
ReduceOp
::
MakeCodegenReducer
()
const
{
switch
(
type
)
{
switch
(
type
)
{
case
ReduceType
::
kSum
:
case
ReduceType
::
kSum
:
return
"tl::SumOp"
;
return
"tl::SumOp"
;
case
ReduceType
::
kAbsSum
:
case
ReduceType
::
kAbsSum
:
return
"tl::SumOp"
;
return
"tl::SumOp"
;
case
ReduceType
::
kMax
:
case
ReduceType
::
kMax
:
return
"tl::MaxOp"
;
return
"tl::MaxOp"
;
case
ReduceType
::
kMin
:
case
ReduceType
::
kMin
:
return
"tl::MinOp"
;
return
"tl::MinOp"
;
default:
default:
ICHECK
(
0
);
ICHECK
(
0
);
return
""
;
return
""
;
}
}
}
}
Stmt
ReduceOp
::
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
{
Stmt
ReduceOp
::
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
{
ICHECK
(
this
->
src
.
scope
()
==
"local.fragment"
&&
this
->
dst
.
scope
()
==
"local.fragment"
)
ICHECK
(
this
->
src
.
scope
()
==
"local.fragment"
&&
this
->
dst
.
scope
()
==
"local.fragment"
)
<<
"Reduce for shared memory not implemented."
;
<<
"Reduce for shared memory not implemented."
;
auto
src_buffer
=
T
.
buffer_remap
[
this
->
src
];
auto
src_buffer
=
T
.
buffer_remap
[
this
->
src
];
auto
dst_buffer
=
T
.
buffer_remap
[
this
->
dst
];
auto
dst_buffer
=
T
.
buffer_remap
[
this
->
dst
];
...
@@ -101,20 +102,24 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
...
@@ -101,20 +102,24 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
Array
<
IterVar
>
dst_vars
;
Array
<
IterVar
>
dst_vars
;
for
(
size_t
i
=
0
;
i
<
dst_layout
->
InputDim
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
dst_layout
->
InputDim
();
i
++
)
{
Var
var
=
Var
(
std
::
string
{
char
(
'i'
+
i
)});
Var
var
=
Var
(
std
::
string
{
char
(
'i'
+
i
)});
dst_vars
.
push_back
(
IterVar
(
Range
(
0
,
dst_layout
->
InputShape
()[
i
]),
var
,
IterVarType
::
kDataPar
));
dst_vars
.
push_back
(
IterVar
(
Range
(
0
,
dst_layout
->
InputShape
()[
i
]),
var
,
IterVarType
::
kDataPar
));
}
}
Array
<
IterVar
>
src_vars
=
dst_vars
;
Array
<
IterVar
>
src_vars
=
dst_vars
;
src_vars
.
insert
(
src_vars
.
begin
()
+
this
->
dim
,
{
Range
(
0
,
src_layout
->
InputShape
()[
this
->
dim
]),
src_vars
.
insert
(
src_vars
.
begin
()
+
this
->
dim
,
Var
(
"rv"
),
IterVarType
::
kDataPar
});
{
Range
(
0
,
src_layout
->
InputShape
()[
this
->
dim
]),
Var
(
"rv"
),
Array
<
PrimExpr
>
src_indices
=
IterVarType
::
kDataPar
});
src_layout
->
Forward
(
src_vars
.
Map
([](
const
auto
&
iv
)
{
return
PrimExpr
(
iv
->
var
);
}));
Array
<
PrimExpr
>
src_indices
=
src_layout
->
Forward
(
Array
<
PrimExpr
>
dst_indices
=
src_vars
.
Map
([](
const
auto
&
iv
)
{
return
PrimExpr
(
iv
->
var
);
}));
dst_layout
->
Forward
(
dst_vars
.
Map
([](
const
auto
&
iv
)
{
return
PrimExpr
(
iv
->
var
);
}));
Array
<
PrimExpr
>
dst_indices
=
dst_layout
->
Forward
(
dst_vars
.
Map
([](
const
auto
&
iv
)
{
return
PrimExpr
(
iv
->
var
);
}));
Array
<
Stmt
>
stmts
;
Array
<
Stmt
>
stmts
;
// make reduce-init stmt
// make reduce-init stmt
if
(
this
->
clear
)
stmts
.
push_back
(
BufferStore
(
dst_buffer
,
this
->
MakeInitValue
(),
dst_indices
));
if
(
this
->
clear
)
stmts
.
push_back
(
BufferStore
(
dst_buffer
,
this
->
MakeInitValue
(),
dst_indices
));
// make thread-local reduce
// make thread-local reduce
Array
<
PrimExpr
>
src_indice_compressed
;
Array
<
PrimExpr
>
src_indice_compressed
;
...
@@ -122,45 +127,50 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
...
@@ -122,45 +127,50 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
for
(
size_t
i
=
0
;
i
<
src_layout
->
OutputDim
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
src_layout
->
OutputDim
();
i
++
)
{
PrimExpr
expr
;
PrimExpr
expr
;
IterVar
var
;
IterVar
var
;
std
::
tie
(
expr
,
var
)
=
std
::
tie
(
expr
,
var
)
=
CompressIterator
(
src_indices
[
i
],
src_vars
,
CompressIterator
(
src_indices
[
i
],
src_vars
,
src_vars
[
this
->
dim
]
->
var
,
analyzer
);
src_vars
[
this
->
dim
]
->
var
,
analyzer
);
src_indice_compressed
.
push_back
(
expr
);
src_indice_compressed
.
push_back
(
expr
);
src_var_compressed
.
push_back
(
var
);
src_var_compressed
.
push_back
(
var
);
}
}
Stmt
reduce_local
=
BufferStore
(
dst_buffer
,
Stmt
reduce_local
=
BufferStore
(
this
->
MakeReduce
(
BufferLoad
(
dst_buffer
,
dst_indices
),
dst_buffer
,
BufferLoad
(
src_buffer
,
src_indice_compressed
)),
this
->
MakeReduce
(
BufferLoad
(
dst_buffer
,
dst_indices
),
dst_indices
);
BufferLoad
(
src_buffer
,
src_indice_compressed
)),
dst_indices
);
for
(
int
i
=
src_layout
->
OutputDim
()
-
1
;
i
>=
0
;
i
--
)
{
for
(
int
i
=
src_layout
->
OutputDim
()
-
1
;
i
>=
0
;
i
--
)
{
reduce_local
=
reduce_local
=
For
(
src_var_compressed
[
i
]
->
var
,
0
,
src_var_compressed
[
i
]
->
dom
->
extent
,
ForKind
::
kUnrolled
,
For
(
src_var_compressed
[
i
]
->
var
,
0
,
src_var_compressed
[
i
]
->
dom
->
extent
,
reduce_local
,
NullOpt
,
{{
tir
::
attr
::
pragma_unroll_explicit
,
Bool
(
false
)}});
ForKind
::
kUnrolled
,
reduce_local
,
NullOpt
,
{{
tir
::
attr
::
pragma_unroll_explicit
,
Bool
(
false
)}});
}
}
stmts
.
push_back
(
reduce_local
);
stmts
.
push_back
(
reduce_local
);
// make inter-thread reduce
// make inter-thread reduce
PrimExpr
src_thread
=
PrimExpr
src_thread
=
src_layout
->
ForwardThread
(
src_layout
->
ForwardThread
(
src_vars
.
Map
([](
const
auto
&
iv
)
{
return
PrimExpr
(
iv
->
var
);
}),
{});
src_vars
.
Map
([](
const
auto
&
iv
)
{
return
PrimExpr
(
iv
->
var
);
}),
{});
auto
iter_sum
=
arith
::
NormalizeToIterSum
(
src_thread
,
ToVMap
(
src_vars
),
analyzer
);
auto
iter_sum
=
for
(
const
auto
&
iter_split
:
iter_sum
->
args
)
{
arith
::
NormalizeToIterSum
(
src_thread
,
ToVMap
(
src_vars
),
analyzer
);
for
(
const
auto
&
iter_split
:
iter_sum
->
args
)
{
auto
mark
=
iter_split
->
source
->
source
.
as
<
Var
>
();
auto
mark
=
iter_split
->
source
->
source
.
as
<
Var
>
();
ICHECK
(
mark
.
defined
());
ICHECK
(
mark
.
defined
());
if
(
mark
.
value
().
same_as
(
src_vars
[
this
->
dim
]
->
var
))
{
if
(
mark
.
value
().
same_as
(
src_vars
[
this
->
dim
]
->
var
))
{
auto
scale
=
as_const_int
(
iter_split
->
scale
);
auto
scale
=
as_const_int
(
iter_split
->
scale
);
auto
extent
=
as_const_int
(
iter_split
->
extent
);
auto
extent
=
as_const_int
(
iter_split
->
extent
);
ICHECK
(
scale
!=
nullptr
&&
extent
!=
nullptr
);
ICHECK
(
scale
!=
nullptr
&&
extent
!=
nullptr
);
if
(
*
extent
==
1
)
continue
;
if
(
*
extent
==
1
)
continue
;
int
reducing_threads
=
(
*
extent
)
*
(
*
scale
);
int
reducing_threads
=
(
*
extent
)
*
(
*
scale
);
std
::
stringstream
ss
;
std
::
stringstream
ss
;
ss
<<
"tl::AllReduce<"
<<
this
->
MakeCodegenReducer
()
<<
", "
<<
reducing_threads
<<
", "
ss
<<
"tl::AllReduce<"
<<
this
->
MakeCodegenReducer
()
<<
", "
<<
(
*
scale
)
<<
">::run"
;
<<
reducing_threads
<<
", "
<<
(
*
scale
)
<<
">::run"
;
Array
<
PrimExpr
>
thread_reduce_args
=
{
StringImm
(
ss
.
str
()),
Array
<
PrimExpr
>
thread_reduce_args
=
{
BufferLoad
(
dst_buffer
,
dst_indices
)};
StringImm
(
ss
.
str
()),
BufferLoad
(
dst_buffer
,
dst_indices
)};
if
(
reducing_threads
>=
32
)
{
if
(
reducing_threads
>=
32
)
{
PrimExpr
workspace
=
T
.
AddWorkspace
(
T
.
block_size
,
dst_buffer
->
dtype
);
PrimExpr
workspace
=
T
.
AddWorkspace
(
T
.
block_size
,
dst_buffer
->
dtype
);
thread_reduce_args
.
push_back
(
workspace
);
thread_reduce_args
.
push_back
(
workspace
);
}
}
auto
call
=
Call
(
dst_buffer
->
dtype
,
builtin
::
call_extern
(),
thread_reduce_args
);
auto
call
=
Call
(
dst_buffer
->
dtype
,
builtin
::
call_extern
(),
thread_reduce_args
);
stmts
.
push_back
(
BufferStore
(
dst_buffer
,
call
,
dst_indices
));
stmts
.
push_back
(
BufferStore
(
dst_buffer
,
call
,
dst_indices
));
}
}
}
}
...
@@ -170,15 +180,17 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
...
@@ -170,15 +180,17 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
// make the outer spatial loop
// make the outer spatial loop
Stmt
body
=
stmts
.
size
()
>
1
?
SeqStmt
(
stmts
)
:
stmts
[
0
];
Stmt
body
=
stmts
.
size
()
>
1
?
SeqStmt
(
stmts
)
:
stmts
[
0
];
for
(
int
i
=
dst_layout
->
InputDim
()
-
1
;
i
>=
0
;
i
--
)
{
for
(
int
i
=
dst_layout
->
InputDim
()
-
1
;
i
>=
0
;
i
--
)
{
body
=
For
(
dst_vars
[
i
]
->
var
,
0
,
dst_vars
[
i
]
->
dom
->
extent
,
ForKind
::
kParallel
,
body
);
body
=
For
(
dst_vars
[
i
]
->
var
,
0
,
dst_vars
[
i
]
->
dom
->
extent
,
ForKind
::
kParallel
,
body
);
}
}
body
=
PartitionLoop
(
Downcast
<
For
>
(
body
),
T
.
thread_var
,
analyzer
,
dst_layout
);
body
=
PartitionLoop
(
Downcast
<
For
>
(
body
),
T
.
thread_var
,
analyzer
,
dst_layout
);
return
body
;
return
body
;
}
}
LayoutMap
ReduceOp
::
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
{
LayoutMap
ReduceOp
::
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
{
if
(
level
>=
InferLevel
::
kStrict
)
return
{};
if
(
level
>=
InferLevel
::
kStrict
)
return
{};
if
(
src
.
scope
()
==
"local.fragment"
&&
dst
.
scope
()
==
"local.fragment"
&&
if
(
src
.
scope
()
==
"local.fragment"
&&
dst
.
scope
()
==
"local.fragment"
&&
T
.
layout_map
.
count
(
src
)
&&
!
T
.
layout_map
.
count
(
dst
))
{
T
.
layout_map
.
count
(
src
)
&&
!
T
.
layout_map
.
count
(
dst
))
{
auto
src_layout
=
T
.
layout_map
[
src
].
as
<
Fragment
>
().
value
();
auto
src_layout
=
T
.
layout_map
[
src
].
as
<
Fragment
>
().
value
();
...
@@ -197,10 +209,11 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
...
@@ -197,10 +209,11 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
fwd
.
push_back
(
InputPlaceholder
(
i
-
1
));
fwd
.
push_back
(
InputPlaceholder
(
i
-
1
));
}
}
}
}
auto
thd
=
auto
thd
=
src_layout
->
ForwardThread
(
src_layout
->
ForwardThread
(
fwd
,
FloorDiv
(
ReplicationPlaceholder
(),
indice_rep_extent
));
fwd
,
FloorDiv
(
ReplicationPlaceholder
(),
indice_rep_extent
));
Fragment
dst_layout
=
Fragment
dst_layout
=
Fragment
(
dst
->
shape
,
{},
thd
,
dest_buffer_rep_extent
,
NullOpt
)
->
CondenseReplicateVar
();
Fragment
(
dst
->
shape
,
{},
thd
,
dest_buffer_rep_extent
,
NullOpt
)
->
CondenseReplicateVar
();
return
{{
dst
,
dst_layout
}};
return
{{
dst
,
dst_layout
}};
}
}
return
{};
return
{};
...
@@ -208,7 +221,8 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
...
@@ -208,7 +221,8 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
TIR_REGISTER_TL_OP
(
ReduceOp
,
reduce
)
TIR_REGISTER_TL_OP
(
ReduceOp
,
reduce
)
.
set_num_inputs
(
4
)
.
set_num_inputs
(
4
)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
\ No newline at end of file
\ No newline at end of file
src/op/reduce.h
View file @
549416f7
...
@@ -18,13 +18,13 @@ namespace tl {
...
@@ -18,13 +18,13 @@ namespace tl {
using
namespace
tir
;
using
namespace
tir
;
class
ReduceOp
:
public
Operator
{
class
ReduceOp
:
public
Operator
{
public:
public:
ReduceOp
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
ReduceOp
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
Stmt
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
final
;
Stmt
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
final
;
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
final
;
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
final
;
static
const
Op
&
Get
();
static
const
Op
&
Get
();
private:
private:
tir
::
Buffer
src
,
dst
;
tir
::
Buffer
src
,
dst
;
int
dim
;
int
dim
;
enum
class
ReduceType
{
enum
class
ReduceType
{
...
@@ -36,11 +36,11 @@ class ReduceOp : public Operator {
...
@@ -36,11 +36,11 @@ class ReduceOp : public Operator {
bool
clear
;
bool
clear
;
PrimExpr
MakeInitValue
()
const
;
PrimExpr
MakeInitValue
()
const
;
PrimExpr
MakeReduce
(
const
PrimExpr
&
a
,
const
PrimExpr
&
b
)
const
;
PrimExpr
MakeReduce
(
const
PrimExpr
&
a
,
const
PrimExpr
&
b
)
const
;
std
::
string
MakeCodegenReducer
()
const
;
std
::
string
MakeCodegenReducer
()
const
;
};
};
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
#endif // TVM_TL_OP_REDUCE_H_
#endif // TVM_TL_OP_REDUCE_H_
\ No newline at end of file
\ No newline at end of file
src/runtime/runtime.cc
View file @
549416f7
...
@@ -17,12 +17,12 @@ namespace tl {
...
@@ -17,12 +17,12 @@ namespace tl {
using
namespace
runtime
;
using
namespace
runtime
;
// Render the first `n` elements of `ptr` as a bracketed, comma-separated
// list, e.g. "[1, 2, 3]". Used for human-readable debug dumps of the
// tensor-map descriptor fields below.
template <typename T>
static std::string ArrayToStr(const T* ptr, size_t n) {
  std::stringstream ss;
  ss << "[";
  const char* sep = "";
  for (size_t i = 0; i < n; ++i) {
    ss << sep << ptr[i];
    sep = ", ";
  }
  ss << "]";
  return ss.str();
}
struct
TensorMapArgs
{
struct
TensorMapArgs
{
CUtensorMap
*
map
;
CUtensorMap
*
map
;
CUtensorMapDataType
type
;
CUtensorMapDataType
type
;
cuuint32_t
tensorRank
;
cuuint32_t
tensorRank
;
void
*
globalAddress
;
void
*
globalAddress
;
cuuint64_t
globalDim
[
5
],
globalStride
[
5
];
cuuint64_t
globalDim
[
5
],
globalStride
[
5
];
cuuint32_t
boxDim
[
5
],
elementStrides
[
5
];
cuuint32_t
boxDim
[
5
],
elementStrides
[
5
];
CUtensorMapInterleave
interleave
;
CUtensorMapInterleave
interleave
;
...
@@ -45,8 +45,9 @@ struct TensorMapArgs {
...
@@ -45,8 +45,9 @@ struct TensorMapArgs {
TensorMapArgs
T
;
TensorMapArgs
T
;
int
idx
=
0
;
int
idx
=
0
;
ICHECK
(
args
.
num_args
>=
8
);
ICHECK
(
args
.
num_args
>=
8
);
T
.
map
=
reinterpret_cast
<
CUtensorMap
*>
(
static_cast
<
void
*>
(
args
[
idx
++
]));
T
.
map
=
reinterpret_cast
<
CUtensorMap
*>
(
static_cast
<
void
*>
(
args
[
idx
++
]));
T
.
type
=
static_cast
<
CUtensorMapDataType
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
type
=
static_cast
<
CUtensorMapDataType
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
tensorRank
=
static_cast
<
cuuint32_t
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
tensorRank
=
static_cast
<
cuuint32_t
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
globalAddress
=
args
[
idx
++
];
T
.
globalAddress
=
args
[
idx
++
];
ICHECK
(
T
.
tensorRank
>=
1
&&
T
.
tensorRank
<=
5
);
ICHECK
(
T
.
tensorRank
>=
1
&&
T
.
tensorRank
<=
5
);
...
@@ -63,10 +64,14 @@ struct TensorMapArgs {
...
@@ -63,10 +64,14 @@ struct TensorMapArgs {
for
(
size_t
i
=
0
;
i
<
T
.
tensorRank
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
T
.
tensorRank
;
i
++
)
{
T
.
elementStrides
[
i
]
=
static_cast
<
cuuint64_t
>
(
args
[
idx
++
]);
T
.
elementStrides
[
i
]
=
static_cast
<
cuuint64_t
>
(
args
[
idx
++
]);
}
}
T
.
interleave
=
static_cast
<
CUtensorMapInterleave
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
interleave
=
T
.
swizzle
=
static_cast
<
CUtensorMapSwizzle
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
static_cast
<
CUtensorMapInterleave
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
l2Promotion
=
static_cast
<
CUtensorMapL2promotion
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
swizzle
=
T
.
oobFill
=
static_cast
<
CUtensorMapFloatOOBfill
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
static_cast
<
CUtensorMapSwizzle
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
l2Promotion
=
static_cast
<
CUtensorMapL2promotion
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
oobFill
=
static_cast
<
CUtensorMapFloatOOBfill
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
return
T
;
return
T
;
}
}
...
@@ -79,7 +84,8 @@ struct TensorMapArgs {
...
@@ -79,7 +84,8 @@ struct TensorMapArgs {
<<
"globalDim "
<<
ArrayToStr
(
globalDim
,
tensorRank
)
<<
std
::
endl
<<
"globalDim "
<<
ArrayToStr
(
globalDim
,
tensorRank
)
<<
std
::
endl
<<
"globalStrides "
<<
ArrayToStr
(
globalStride
,
tensorRank
)
<<
std
::
endl
<<
"globalStrides "
<<
ArrayToStr
(
globalStride
,
tensorRank
)
<<
std
::
endl
<<
"boxDim "
<<
ArrayToStr
(
boxDim
,
tensorRank
)
<<
std
::
endl
<<
"boxDim "
<<
ArrayToStr
(
boxDim
,
tensorRank
)
<<
std
::
endl
<<
"elementStrides "
<<
ArrayToStr
(
elementStrides
,
tensorRank
)
<<
std
::
endl
<<
"elementStrides "
<<
ArrayToStr
(
elementStrides
,
tensorRank
)
<<
std
::
endl
<<
"interleave "
<<
interleave
<<
std
::
endl
<<
"interleave "
<<
interleave
<<
std
::
endl
<<
"swizzle "
<<
swizzle
<<
std
::
endl
<<
"swizzle "
<<
swizzle
<<
std
::
endl
<<
"l2Promotion "
<<
l2Promotion
<<
std
::
endl
<<
"l2Promotion "
<<
l2Promotion
<<
std
::
endl
...
@@ -89,23 +95,26 @@ struct TensorMapArgs {
...
@@ -89,23 +95,26 @@ struct TensorMapArgs {
};
};
// set device api
// set device api
// Packed function: build a tiled CUDA TMA descriptor on the host.
// Returns the raw CUresult code to the caller; aborts with a full
// descriptor dump if encoding fails.
TVM_REGISTER_GLOBAL(tvm_tensormap_create_tiled)
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      // Decode the packed-call arguments into a descriptor request.
      TensorMapArgs desc = TensorMapArgs::Extract(args);
      // NOTE(review): `globalStride + 1` presumably skips the implicit
      // innermost (element-sized) stride the driver API expects to be
      // omitted -- confirm against the cuTensorMapEncodeTiled docs.
      CUresult result = cuTensorMapEncodeTiled(
          desc.map, desc.type, desc.tensorRank, desc.globalAddress,
          desc.globalDim, desc.globalStride + 1, desc.boxDim,
          desc.elementStrides, desc.interleave, desc.swizzle,
          desc.l2Promotion, desc.oobFill);
      if (result != CUDA_SUCCESS) {
        // Dump every field so a bad encoding is diagnosable from the log.
        LOG_FATAL << "Failed to initialize the TMA descriptor " << result
                  << std::endl << desc.ToDebugString();
      }
      *ret = static_cast<int>(result);
    });
struct
TensorMapIm2ColArgs
{
struct
TensorMapIm2ColArgs
{
CUtensorMap
*
map
;
CUtensorMap
*
map
;
CUtensorMapDataType
type
;
CUtensorMapDataType
type
;
cuuint32_t
tensorRank
;
cuuint32_t
tensorRank
;
void
*
globalAddress
;
void
*
globalAddress
;
cuuint64_t
globalDim
[
5
],
globalStride
[
5
];
cuuint64_t
globalDim
[
5
],
globalStride
[
5
];
cuuint32_t
elementStrides
[
5
];
cuuint32_t
elementStrides
[
5
];
int
pixelBoxLowerCorner
[
3
],
pixelBoxUpperCorner
[
3
];
int
pixelBoxLowerCorner
[
3
],
pixelBoxUpperCorner
[
3
];
...
@@ -119,8 +128,9 @@ struct TensorMapIm2ColArgs {
...
@@ -119,8 +128,9 @@ struct TensorMapIm2ColArgs {
TensorMapIm2ColArgs
T
;
TensorMapIm2ColArgs
T
;
int
idx
=
0
;
int
idx
=
0
;
ICHECK
(
args
.
num_args
>=
8
);
ICHECK
(
args
.
num_args
>=
8
);
T
.
map
=
reinterpret_cast
<
CUtensorMap
*>
(
static_cast
<
void
*>
(
args
[
idx
++
]));
T
.
map
=
reinterpret_cast
<
CUtensorMap
*>
(
static_cast
<
void
*>
(
args
[
idx
++
]));
T
.
type
=
static_cast
<
CUtensorMapDataType
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
type
=
static_cast
<
CUtensorMapDataType
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
tensorRank
=
static_cast
<
cuuint32_t
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
tensorRank
=
static_cast
<
cuuint32_t
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
globalAddress
=
args
[
idx
++
];
T
.
globalAddress
=
args
[
idx
++
];
ICHECK
(
T
.
tensorRank
>=
3
&&
T
.
tensorRank
<=
5
);
ICHECK
(
T
.
tensorRank
>=
3
&&
T
.
tensorRank
<=
5
);
...
@@ -142,10 +152,14 @@ struct TensorMapIm2ColArgs {
...
@@ -142,10 +152,14 @@ struct TensorMapIm2ColArgs {
}
}
T
.
smem_box_pixel
=
static_cast
<
cuuint64_t
>
(
args
[
idx
++
]);
T
.
smem_box_pixel
=
static_cast
<
cuuint64_t
>
(
args
[
idx
++
]);
T
.
smem_box_channel
=
static_cast
<
cuuint64_t
>
(
args
[
idx
++
]);
T
.
smem_box_channel
=
static_cast
<
cuuint64_t
>
(
args
[
idx
++
]);
T
.
interleave
=
static_cast
<
CUtensorMapInterleave
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
interleave
=
T
.
swizzle
=
static_cast
<
CUtensorMapSwizzle
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
static_cast
<
CUtensorMapInterleave
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
l2Promotion
=
static_cast
<
CUtensorMapL2promotion
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
swizzle
=
T
.
oobFill
=
static_cast
<
CUtensorMapFloatOOBfill
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
static_cast
<
CUtensorMapSwizzle
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
l2Promotion
=
static_cast
<
CUtensorMapL2promotion
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
T
.
oobFill
=
static_cast
<
CUtensorMapFloatOOBfill
>
(
static_cast
<
int64_t
>
(
args
[
idx
++
]));
return
T
;
return
T
;
}
}
...
@@ -159,9 +173,12 @@ struct TensorMapIm2ColArgs {
...
@@ -159,9 +173,12 @@ struct TensorMapIm2ColArgs {
<<
"globalStrides "
<<
ArrayToStr
(
globalStride
,
tensorRank
)
<<
std
::
endl
<<
"globalStrides "
<<
ArrayToStr
(
globalStride
,
tensorRank
)
<<
std
::
endl
<<
"smem_box_pixel "
<<
smem_box_pixel
<<
std
::
endl
<<
"smem_box_pixel "
<<
smem_box_pixel
<<
std
::
endl
<<
"smem_box_channel "
<<
smem_box_channel
<<
std
::
endl
<<
"smem_box_channel "
<<
smem_box_channel
<<
std
::
endl
<<
"pixelBoxLowerCorner "
<<
ArrayToStr
(
pixelBoxLowerCorner
,
tensorRank
-
2
)
<<
std
::
endl
<<
"pixelBoxLowerCorner "
<<
"pixelBoxUpperCorner "
<<
ArrayToStr
(
pixelBoxUpperCorner
,
tensorRank
-
2
)
<<
std
::
endl
<<
ArrayToStr
(
pixelBoxLowerCorner
,
tensorRank
-
2
)
<<
std
::
endl
<<
"elementStrides "
<<
ArrayToStr
(
elementStrides
,
tensorRank
)
<<
std
::
endl
<<
"pixelBoxUpperCorner "
<<
ArrayToStr
(
pixelBoxUpperCorner
,
tensorRank
-
2
)
<<
std
::
endl
<<
"elementStrides "
<<
ArrayToStr
(
elementStrides
,
tensorRank
)
<<
std
::
endl
<<
"interleave "
<<
interleave
<<
std
::
endl
<<
"interleave "
<<
interleave
<<
std
::
endl
<<
"swizzle "
<<
swizzle
<<
std
::
endl
<<
"swizzle "
<<
swizzle
<<
std
::
endl
<<
"l2Promotion "
<<
l2Promotion
<<
std
::
endl
<<
"l2Promotion "
<<
l2Promotion
<<
std
::
endl
...
@@ -170,18 +187,21 @@ struct TensorMapIm2ColArgs {
...
@@ -170,18 +187,21 @@ struct TensorMapIm2ColArgs {
}
}
};
};
// Packed function: build an im2col CUDA TMA descriptor on the host.
// Mirrors tvm_tensormap_create_tiled but forwards the im2col-specific
// pixel-box corners and shared-memory box extents.
TVM_REGISTER_GLOBAL(tvm_tensormap_create_im2col)
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      // Decode the packed-call arguments into a descriptor request.
      TensorMapIm2ColArgs desc = TensorMapIm2ColArgs::Extract(args);
      // NOTE(review): `globalStride + 1` presumably skips the implicit
      // innermost (element-sized) stride -- confirm against the
      // cuTensorMapEncodeIm2col documentation.
      CUresult result = cuTensorMapEncodeIm2col(
          desc.map, desc.type, desc.tensorRank, desc.globalAddress,
          desc.globalDim, desc.globalStride + 1, desc.pixelBoxLowerCorner,
          desc.pixelBoxUpperCorner, desc.smem_box_channel,
          desc.smem_box_pixel, desc.elementStrides, desc.interleave,
          desc.swizzle, desc.l2Promotion, desc.oobFill);
      if (result != CUDA_SUCCESS) {
        // Dump every field so a bad encoding is diagnosable from the log.
        LOG_FATAL << "Failed to initialize the TMA descriptor " << result
                  << std::endl << desc.ToDebugString();
      }
      *ret = static_cast<int>(result);
    });
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/runtime/runtime.h
View file @
549416f7
...
@@ -13,9 +13,11 @@
...
@@ -13,9 +13,11 @@
namespace tvm {
namespace tl {

// Global names of the packed functions registered in runtime.cc that
// build CUDA TMA (tensor memory accelerator) descriptors on the host
// (cuTensorMapEncodeTiled / cuTensorMapEncodeIm2col wrappers).
constexpr const char* tvm_tensormap_create_tiled =
    "__tvm_tensormap_create_tiled";
constexpr const char* tvm_tensormap_create_im2col =
    "__tvm_tensormap_create_im2col";

}  // namespace tl
}  // namespace tvm
#endif // TVM_TL_RUNTIME_RUNTIME_H_
#endif // TVM_TL_RUNTIME_RUNTIME_H_
\ No newline at end of file
\ No newline at end of file
src/target/codegen_cuda.cc
View file @
549416f7
This diff is collapsed.
Click to expand it.
src/target/codegen_cuda.h
View file @
549416f7
...
@@ -21,50 +21,58 @@ namespace tvm {
...
@@ -21,50 +21,58 @@ namespace tvm {
namespace
codegen
{
namespace
codegen
{
class
CodeGenTileLangCUDA
final
:
public
CodeGenC
{
class
CodeGenTileLangCUDA
final
:
public
CodeGenC
{
public:
public:
CodeGenTileLangCUDA
();
CodeGenTileLangCUDA
();
std
::
string
Finish
();
std
::
string
Finish
();
// override behavior
// override behavior
void
PrintFuncPrefix
(
std
::
ostream
&
os
)
final
;
void
PrintFuncPrefix
(
std
::
ostream
&
os
)
final
;
void
PrintExtraAttrs
(
const
PrimFunc
&
f
,
std
::
ostream
&
os
)
final
;
void
PrintExtraAttrs
(
const
PrimFunc
&
f
,
std
::
ostream
&
os
)
final
;
void
VisitStmt_
(
const
ForNode
*
op
)
final
;
void
VisitStmt_
(
const
ForNode
*
op
)
final
;
void
PrintStorageSync
(
const
CallNode
*
op
)
final
;
void
PrintStorageSync
(
const
CallNode
*
op
)
final
;
void
PrintStorageScope
(
const
std
::
string
&
scope
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintStorageScope
(
const
std
::
string
&
scope
,
void
PrintVecBinaryOp
(
const
std
::
string
&
op
,
DataType
t
,
PrimExpr
lhs
,
PrimExpr
rhs
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintVecBinaryOp
(
const
std
::
string
&
op
,
DataType
t
,
PrimExpr
lhs
,
void
PrintType
(
DataType
t
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
PrimExpr
rhs
,
void
PrintVecElemLoad
(
const
std
::
string
&
vec
,
DataType
t
,
int
i
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintType
(
DataType
t
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintVecElemStore
(
const
std
::
string
&
vec
,
DataType
t
,
int
i
,
const
std
::
string
&
value
)
final
;
void
PrintVecElemLoad
(
const
std
::
string
&
vec
,
DataType
t
,
int
i
,
void
BindThreadIndex
(
const
IterVar
&
iv
)
final
;
// NOLINT(*)
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintVecElemLoadExpr
(
DataType
t
,
int
i
,
const
std
::
string
&
value
,
std
::
ostream
&
os
)
final
;
void
PrintVecElemStore
(
const
std
::
string
&
vec
,
DataType
t
,
int
i
,
std
::
string
CastFromTo
(
std
::
string
value
,
DataType
from
,
DataType
target
)
final
;
const
std
::
string
&
value
)
final
;
void
BindThreadIndex
(
const
IterVar
&
iv
)
final
;
// NOLINT(*)
void
PrintVecElemLoadExpr
(
DataType
t
,
int
i
,
const
std
::
string
&
value
,
std
::
ostream
&
os
)
final
;
std
::
string
CastFromTo
(
std
::
string
value
,
DataType
from
,
DataType
target
)
final
;
// overload visitor
// overload visitor
void
VisitExpr_
(
const
RampNode
*
op
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
VisitExpr_
(
const
RampNode
*
op
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
VisitExpr_
(
const
BroadcastNode
*
op
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
VisitExpr_
(
const
BroadcastNode
*
op
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
VisitExpr_
(
const
FloatImmNode
*
op
,
std
::
ostream
&
os
)
final
;
void
VisitExpr_
(
const
FloatImmNode
*
op
,
std
::
ostream
&
os
)
final
;
void
VisitExpr_
(
const
CallNode
*
op
,
std
::
ostream
&
os
)
final
;
void
VisitExpr_
(
const
CallNode
*
op
,
std
::
ostream
&
os
)
final
;
void
VisitExpr_
(
const
CastNode
*
op
,
std
::
ostream
&
os
)
final
;
void
VisitExpr_
(
const
CastNode
*
op
,
std
::
ostream
&
os
)
final
;
void
VisitStmt_
(
const
AllocateNode
*
op
)
final
;
void
VisitStmt_
(
const
AllocateNode
*
op
)
final
;
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
;
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
;
// Override this as a work around for __grid_constant__ parameter
// Override this as a work around for __grid_constant__ parameter
void
AddFunction
(
const
PrimFunc
&
f
);
void
AddFunction
(
const
PrimFunc
&
f
);
protected:
protected:
virtual
std
::
string
GetBufferRef
(
DataType
t
,
const
BufferNode
*
buffer
,
PrimExpr
index
)
final
;
virtual
std
::
string
GetBufferRef
(
DataType
t
,
const
BufferNode
*
buffer
,
void
PrintCallExtern
(
Type
ret_type
,
String
global_symbol
,
const
Array
<
PrimExpr
>&
args
,
PrimExpr
index
)
final
;
bool
skip_first_arg
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintCallExtern
(
Type
ret_type
,
String
global_symbol
,
const
Array
<
PrimExpr
>
&
args
,
bool
skip_first_arg
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
private:
private:
// Handle volatile loads
// Handle volatile loads
void
HandleVolatileLoads
(
const
std
::
string
&
value
,
const
BufferLoadNode
*
op
,
void
HandleVolatileLoads
(
const
std
::
string
&
value
,
const
BufferLoadNode
*
op
,
std
::
ostream
&
os
)
final
;
std
::
ostream
&
os
)
final
;
// Whether scope such as "__shared__" or "__constant__" is part of type.
// Whether scope such as "__shared__" or "__constant__" is part of type.
bool
IsScopePartOfType
()
const
final
{
return
false
;
}
bool
IsScopePartOfType
()
const
final
{
return
false
;
}
friend
void
PrintConst
(
const
FloatImmNode
*
op
,
std
::
ostream
&
os
,
CodeGenTileLangCUDA
*
p
);
friend
void
PrintConst
(
const
FloatImmNode
*
op
,
std
::
ostream
&
os
,
CodeGenTileLangCUDA
*
p
);
// The size of the barrier array in shared memory
// The size of the barrier array in shared memory
int
barrier_count_
=
-
1
;
int
barrier_count_
=
-
1
;
// whether need mma.h
// whether need mma.h
...
@@ -77,15 +85,17 @@ class CodeGenTileLangCUDA final : public CodeGenC {
...
@@ -77,15 +85,17 @@ class CodeGenTileLangCUDA final : public CodeGenC {
// Set to 16 to maintain minimum alignment requirements for async bulk copy
// Set to 16 to maintain minimum alignment requirements for async bulk copy
const
int
barrier_alignment_bytes_
=
16
;
const
int
barrier_alignment_bytes_
=
16
;
std
::
unordered_map
<
const
VarNode
*
,
std
::
string
>
fragment_shapes
;
std
::
unordered_map
<
const
VarNode
*
,
std
::
string
>
fragment_shapes
;
std
::
unordered_map
<
const
VarNode
*
,
std
::
string
>
fragment_layouts
;
std
::
unordered_map
<
const
VarNode
*
,
std
::
string
>
fragment_layouts
;
friend
void
PrintConst
(
const
FloatImmNode
*
op
,
std
::
ostream
&
os
,
CodeGenTileLangCUDA
*
p
);
friend
void
PrintConst
(
const
FloatImmNode
*
op
,
std
::
ostream
&
os
,
void
PrintWmmaScope
(
const
std
::
string
&
scope
,
DataType
t
,
const
VarNode
*
variable
,
CodeGenTileLangCUDA
*
p
);
std
::
ostream
&
os
);
void
PrintWmmaScope
(
const
std
::
string
&
scope
,
DataType
t
,
int32_t
GetWmmaFragmentSize
(
const
std
::
string
&
scope
,
const
VarNode
*
variable
,
int32_t
size
);
const
VarNode
*
variable
,
std
::
ostream
&
os
);
int32_t
GetWmmaFragmentSize
(
const
std
::
string
&
scope
,
const
VarNode
*
variable
,
int32_t
size
);
};
};
}
// namespace codegen
}
// namespace codegen
}
// namespace tvm
}
// namespace tvm
#endif
// TVM_TL_TARGET_CODEGEN_CUDA_H_
#endif // TVM_TL_TARGET_CODEGEN_CUDA_H_
src/target/codegen_hip.cc
View file @
549416f7
This diff is collapsed.
Click to expand it.
src/target/codegen_hip.h
View file @
549416f7
...
@@ -21,50 +21,58 @@ namespace tvm {
...
@@ -21,50 +21,58 @@ namespace tvm {
namespace
codegen
{
namespace
codegen
{
class
CodeGenTileLangHIP
final
:
public
CodeGenC
{
class
CodeGenTileLangHIP
final
:
public
CodeGenC
{
public:
public:
CodeGenTileLangHIP
();
CodeGenTileLangHIP
();
std
::
string
Finish
();
std
::
string
Finish
();
// override behavior
// override behavior
void
PrintFuncPrefix
(
std
::
ostream
&
os
)
final
;
void
PrintFuncPrefix
(
std
::
ostream
&
os
)
final
;
void
PrintExtraAttrs
(
const
PrimFunc
&
f
,
std
::
ostream
&
os
)
final
;
void
PrintExtraAttrs
(
const
PrimFunc
&
f
,
std
::
ostream
&
os
)
final
;
void
VisitStmt_
(
const
ForNode
*
op
)
final
;
void
VisitStmt_
(
const
ForNode
*
op
)
final
;
void
PrintStorageSync
(
const
CallNode
*
op
)
final
;
void
PrintStorageSync
(
const
CallNode
*
op
)
final
;
void
PrintStorageScope
(
const
std
::
string
&
scope
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintStorageScope
(
const
std
::
string
&
scope
,
void
PrintVecBinaryOp
(
const
std
::
string
&
op
,
DataType
t
,
PrimExpr
lhs
,
PrimExpr
rhs
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintVecBinaryOp
(
const
std
::
string
&
op
,
DataType
t
,
PrimExpr
lhs
,
void
PrintType
(
DataType
t
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
PrimExpr
rhs
,
void
PrintVecElemLoad
(
const
std
::
string
&
vec
,
DataType
t
,
int
i
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintType
(
DataType
t
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintVecElemStore
(
const
std
::
string
&
vec
,
DataType
t
,
int
i
,
const
std
::
string
&
value
)
final
;
void
PrintVecElemLoad
(
const
std
::
string
&
vec
,
DataType
t
,
int
i
,
void
BindThreadIndex
(
const
IterVar
&
iv
)
final
;
// NOLINT(*)
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintVecElemLoadExpr
(
DataType
t
,
int
i
,
const
std
::
string
&
value
,
std
::
ostream
&
os
)
final
;
void
PrintVecElemStore
(
const
std
::
string
&
vec
,
DataType
t
,
int
i
,
std
::
string
CastFromTo
(
std
::
string
value
,
DataType
from
,
DataType
target
)
final
;
const
std
::
string
&
value
)
final
;
void
BindThreadIndex
(
const
IterVar
&
iv
)
final
;
// NOLINT(*)
void
PrintVecElemLoadExpr
(
DataType
t
,
int
i
,
const
std
::
string
&
value
,
std
::
ostream
&
os
)
final
;
std
::
string
CastFromTo
(
std
::
string
value
,
DataType
from
,
DataType
target
)
final
;
// overload visitor
// overload visitor
void
VisitExpr_
(
const
RampNode
*
op
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
VisitExpr_
(
const
RampNode
*
op
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
VisitExpr_
(
const
BroadcastNode
*
op
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
VisitExpr_
(
const
BroadcastNode
*
op
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
VisitExpr_
(
const
FloatImmNode
*
op
,
std
::
ostream
&
os
)
final
;
void
VisitExpr_
(
const
FloatImmNode
*
op
,
std
::
ostream
&
os
)
final
;
void
VisitExpr_
(
const
CallNode
*
op
,
std
::
ostream
&
os
)
final
;
void
VisitExpr_
(
const
CallNode
*
op
,
std
::
ostream
&
os
)
final
;
void
VisitExpr_
(
const
CastNode
*
op
,
std
::
ostream
&
os
)
final
;
void
VisitExpr_
(
const
CastNode
*
op
,
std
::
ostream
&
os
)
final
;
void
VisitStmt_
(
const
AllocateNode
*
op
)
final
;
void
VisitStmt_
(
const
AllocateNode
*
op
)
final
;
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
;
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
;
// Override this as a work around for __grid_constant__ parameter
// Override this as a work around for __grid_constant__ parameter
void
AddFunction
(
const
PrimFunc
&
f
);
void
AddFunction
(
const
PrimFunc
&
f
);
protected:
protected:
virtual
std
::
string
GetBufferRef
(
DataType
t
,
const
BufferNode
*
buffer
,
PrimExpr
index
)
final
;
virtual
std
::
string
GetBufferRef
(
DataType
t
,
const
BufferNode
*
buffer
,
void
PrintCallExtern
(
Type
ret_type
,
String
global_symbol
,
const
Array
<
PrimExpr
>&
args
,
PrimExpr
index
)
final
;
bool
skip_first_arg
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintCallExtern
(
Type
ret_type
,
String
global_symbol
,
const
Array
<
PrimExpr
>
&
args
,
bool
skip_first_arg
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
private:
private:
// Handle volatile loads
// Handle volatile loads
void
HandleVolatileLoads
(
const
std
::
string
&
value
,
const
BufferLoadNode
*
op
,
void
HandleVolatileLoads
(
const
std
::
string
&
value
,
const
BufferLoadNode
*
op
,
std
::
ostream
&
os
)
final
;
std
::
ostream
&
os
)
final
;
// Whether scope such as "__shared__" or "__constant__" is part of type.
// Whether scope such as "__shared__" or "__constant__" is part of type.
bool
IsScopePartOfType
()
const
final
{
return
false
;
}
bool
IsScopePartOfType
()
const
final
{
return
false
;
}
friend
void
PrintConst
(
const
FloatImmNode
*
op
,
std
::
ostream
&
os
,
CodeGenTileLangHIP
*
p
);
friend
void
PrintConst
(
const
FloatImmNode
*
op
,
std
::
ostream
&
os
,
CodeGenTileLangHIP
*
p
);
// whether need math_constants.h
// whether need math_constants.h
bool
need_math_constants_h_
{
false
};
bool
need_math_constants_h_
{
false
};
...
@@ -83,7 +91,7 @@ class CodeGenTileLangHIP final : public CodeGenC {
...
@@ -83,7 +91,7 @@ class CodeGenTileLangHIP final : public CodeGenC {
const
int
barrier_alignment_bytes_
=
16
;
const
int
barrier_alignment_bytes_
=
16
;
};
};
}
// namespace codegen
}
// namespace codegen
}
// namespace tvm
}
// namespace tvm
#endif
// TVM_TL_TARGET_CODEGEN_HIP_H_
#endif // TVM_TL_TARGET_CODEGEN_HIP_H_
src/target/cuda.h
View file @
549416f7
This diff is collapsed.
Click to expand it.
src/target/rt_mod_cuda.cc
View file @
549416f7
// Copyright (c) Microsoft Corporation.
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// Licensed under the MIT License.
#include "runtime/cuda/cuda_module.h"
#include "codegen_cuda.h"
#include "codegen_cuda.h"
#include "runtime/cuda/cuda_module.h"
namespace
tvm
{
namespace
tvm
{
namespace
codegen
{
namespace
codegen
{
static
std
::
unordered_map
<
std
::
string
,
runtime
::
FunctionInfo
>
ExtractFuncInfo
(
const
IRModule
&
mod
)
{
static
std
::
unordered_map
<
std
::
string
,
runtime
::
FunctionInfo
>
ExtractFuncInfo
(
const
IRModule
&
mod
)
{
std
::
unordered_map
<
std
::
string
,
runtime
::
FunctionInfo
>
fmap
;
std
::
unordered_map
<
std
::
string
,
runtime
::
FunctionInfo
>
fmap
;
for
(
auto
kv
:
mod
->
functions
)
{
for
(
auto
kv
:
mod
->
functions
)
{
ICHECK
(
kv
.
second
->
IsInstance
<
tir
::
PrimFuncNode
>
())
<<
"Can only lower IR Module with PrimFuncs"
;
ICHECK
(
kv
.
second
->
IsInstance
<
tir
::
PrimFuncNode
>
())
<<
"Can only lower IR Module with PrimFuncs"
;
auto
f
=
Downcast
<
tir
::
PrimFunc
>
(
kv
.
second
);
auto
f
=
Downcast
<
tir
::
PrimFunc
>
(
kv
.
second
);
runtime
::
FunctionInfo
info
;
runtime
::
FunctionInfo
info
;
...
@@ -26,7 +28,7 @@ static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
...
@@ -26,7 +28,7 @@ static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
info
.
arg_types
.
push_back
(
f
->
params
[
i
].
dtype
());
info
.
arg_types
.
push_back
(
f
->
params
[
i
].
dtype
());
}
}
if
(
auto
opt
=
f
->
GetAttr
<
Array
<
String
>>
(
tir
::
attr
::
kKernelLaunchParams
))
{
if
(
auto
opt
=
f
->
GetAttr
<
Array
<
String
>>
(
tir
::
attr
::
kKernelLaunchParams
))
{
for
(
const
auto
&
tag
:
opt
.
value
())
{
for
(
const
auto
&
tag
:
opt
.
value
())
{
info
.
launch_param_tags
.
push_back
(
tag
);
info
.
launch_param_tags
.
push_back
(
tag
);
}
}
}
}
...
@@ -43,7 +45,8 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
...
@@ -43,7 +45,8 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
cg
.
Init
(
output_ssa
);
cg
.
Init
(
output_ssa
);
for
(
auto
kv
:
mod
->
functions
)
{
for
(
auto
kv
:
mod
->
functions
)
{
ICHECK
(
kv
.
second
->
IsInstance
<
PrimFuncNode
>
())
<<
"CodeGenTileLangCUDA: Can only take PrimFunc"
;
ICHECK
(
kv
.
second
->
IsInstance
<
PrimFuncNode
>
())
<<
"CodeGenTileLangCUDA: Can only take PrimFunc"
;
auto
f
=
Downcast
<
PrimFunc
>
(
kv
.
second
);
auto
f
=
Downcast
<
PrimFunc
>
(
kv
.
second
);
auto
calling_conv
=
f
->
GetAttr
<
Integer
>
(
tvm
::
attr
::
kCallingConv
);
auto
calling_conv
=
f
->
GetAttr
<
Integer
>
(
tvm
::
attr
::
kCallingConv
);
ICHECK
(
calling_conv
==
CallingConv
::
kDeviceKernelLaunch
);
ICHECK
(
calling_conv
==
CallingConv
::
kDeviceKernelLaunch
);
...
@@ -51,14 +54,15 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
...
@@ -51,14 +54,15 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
}
}
std
::
string
code
=
cg
.
Finish
();
std
::
string
code
=
cg
.
Finish
();
if
(
const
auto
*
f
=
Registry
::
Get
(
"tvm_callback_cuda_postproc"
))
{
if
(
const
auto
*
f
=
Registry
::
Get
(
"tvm_callback_cuda_postproc"
))
{
code
=
(
*
f
)(
code
,
target
).
operator
std
::
string
();
code
=
(
*
f
)(
code
,
target
).
operator
std
::
string
();
}
}
std
::
string
fmt
=
"ptx"
;
std
::
string
fmt
=
"ptx"
;
std
::
string
ptx
;
std
::
string
ptx
;
if
(
const
auto
*
f
=
Registry
::
Get
(
"tvm_callback_cuda_compile"
))
{
if
(
const
auto
*
f
=
Registry
::
Get
(
"tvm_callback_cuda_compile"
))
{
ptx
=
(
*
f
)(
code
,
target
).
operator
std
::
string
();
ptx
=
(
*
f
)(
code
,
target
).
operator
std
::
string
();
if
(
ptx
[
0
]
!=
'/'
)
fmt
=
"cubin"
;
if
(
ptx
[
0
]
!=
'/'
)
fmt
=
"cubin"
;
}
else
{
}
else
{
ICHECK
(
0
);
ICHECK
(
0
);
}
}
...
@@ -72,7 +76,8 @@ String BuildTLDebug(IRModule mod, Target target) {
...
@@ -72,7 +76,8 @@ String BuildTLDebug(IRModule mod, Target target) {
cg
.
Init
(
output_ssa
);
cg
.
Init
(
output_ssa
);
for
(
auto
kv
:
mod
->
functions
)
{
for
(
auto
kv
:
mod
->
functions
)
{
ICHECK
(
kv
.
second
->
IsInstance
<
PrimFuncNode
>
())
<<
"CodeGenTileLangCUDA: Can only take PrimFunc"
;
ICHECK
(
kv
.
second
->
IsInstance
<
PrimFuncNode
>
())
<<
"CodeGenTileLangCUDA: Can only take PrimFunc"
;
auto
f
=
Downcast
<
PrimFunc
>
(
kv
.
second
);
auto
f
=
Downcast
<
PrimFunc
>
(
kv
.
second
);
auto
calling_conv
=
f
->
GetAttr
<
Integer
>
(
tvm
::
attr
::
kCallingConv
);
auto
calling_conv
=
f
->
GetAttr
<
Integer
>
(
tvm
::
attr
::
kCallingConv
);
ICHECK
(
calling_conv
==
CallingConv
::
kDeviceKernelLaunch
);
ICHECK
(
calling_conv
==
CallingConv
::
kDeviceKernelLaunch
);
...
@@ -80,14 +85,16 @@ String BuildTLDebug(IRModule mod, Target target) {
...
@@ -80,14 +85,16 @@ String BuildTLDebug(IRModule mod, Target target) {
}
}
std
::
string
code
=
cg
.
Finish
();
std
::
string
code
=
cg
.
Finish
();
if
(
const
auto
*
f
=
Registry
::
Get
(
"tvm_callback_cuda_postproc"
))
{
if
(
const
auto
*
f
=
Registry
::
Get
(
"tvm_callback_cuda_postproc"
))
{
code
=
(
*
f
)(
code
,
target
).
operator
std
::
string
();
code
=
(
*
f
)(
code
,
target
).
operator
std
::
string
();
}
}
return
String
(
code
);
return
String
(
code
);
}
}
TVM_REGISTER_GLOBAL
(
"target.build.tilelang_cuda"
).
set_body_typed
(
BuildTileLangCUDA
);
TVM_REGISTER_GLOBAL
(
"target.build.tilelang_cuda"
)
TVM_REGISTER_GLOBAL
(
"target.build.tl_debug_codegen"
).
set_body_typed
(
BuildTLDebug
);
.
set_body_typed
(
BuildTileLangCUDA
);
TVM_REGISTER_GLOBAL
(
"target.build.tl_debug_codegen"
)
.
set_body_typed
(
BuildTLDebug
);
}
// namespace codegen
}
// namespace codegen
}
// namespace tvm
}
// namespace tvm
src/target/rt_mod_hip.cc
View file @
549416f7
// Copyright (c) Microsoft Corporation.
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// Licensed under the MIT License.
#if defined(__linux__)
#if defined(__linux__)
#include <sys/stat.h>
#include <sys/stat.h>
#endif
#endif
...
@@ -8,28 +8,28 @@
...
@@ -8,28 +8,28 @@
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <hip/hiprtc.h>
#include <hip/hiprtc.h>
#include "runtime/rocm/rocm_module.h"
#include "codegen_hip.h"
#include "codegen_hip.h"
#include "runtime/rocm/rocm_module.h"
namespace
tvm
{
namespace
tvm
{
namespace
codegen
{
namespace
codegen
{
#define HIPRTC_CALL(x) \
#define HIPRTC_CALL(x) \
\
\
{
\
{
\
\
\
hiprtcResult
result
=
x
;
\
hiprtcResult
result
=
x
;
\
\
\
if
(
result
!=
HIPRTC_SUCCESS
)
{
\
if
(
result
!=
HIPRTC_SUCCESS
)
{
\
\
\
LOG
(
FATAL
)
\
LOG
(
FATAL
)
\
<<
"HiprtcError: "
#
x
" failed with error: "
<<
hiprtcGetErrorString
(
result
);
\
<<
"HiprtcError: "
#
x
" failed with error: "
\
<<
hiprtcGetErrorString
(
result
);
\
\
\
\
\
}
\
}
\
\
\
\
\
}
}
static
std
::
string
FindHIPIncludePath
()
{
static
std
::
string
FindHIPIncludePath
()
{
...
@@ -39,7 +39,7 @@ static std::string FindHIPIncludePath() {
...
@@ -39,7 +39,7 @@ static std::string FindHIPIncludePath() {
const
std
::
string
delimiter
=
"/"
;
const
std
::
string
delimiter
=
"/"
;
#endif
#endif
std
::
string
hip_include_path
;
std
::
string
hip_include_path
;
const
char
*
hip_path_env
=
std
::
getenv
(
"HIP_PATH"
);
const
char
*
hip_path_env
=
std
::
getenv
(
"HIP_PATH"
);
if
(
hip_path_env
!=
nullptr
)
{
if
(
hip_path_env
!=
nullptr
)
{
hip_include_path
+=
hip_path_env
;
hip_include_path
+=
hip_path_env
;
hip_include_path
+=
delimiter
+
"include"
;
hip_include_path
+=
delimiter
+
"include"
;
...
@@ -58,19 +58,24 @@ static std::string FindHIPIncludePath() {
...
@@ -58,19 +58,24 @@ static std::string FindHIPIncludePath() {
}
}
#endif
#endif
LOG
(
FATAL
)
<<
"Cannot find HIP include path."
LOG
(
FATAL
)
<<
"Cannot find HIP include path."
<<
"HIP_PATH is not set or ROCm is not installed in the default installation path."
<<
"HIP_PATH is not set or ROCm is not installed in the default "
"installation path."
<<
"In other than linux, it is necessary to set HIP_PATH."
;
<<
"In other than linux, it is necessary to set HIP_PATH."
;
return
hip_include_path
;
return
hip_include_path
;
}
}
static
std
::
string
HIPRTCCompile
(
const
std
::
string
&
code
,
bool
include_path
=
false
)
{
static
std
::
string
HIPRTCCompile
(
const
std
::
string
&
code
,
bool
include_path
=
false
)
{
std
::
vector
<
std
::
string
>
compile_params
;
std
::
vector
<
std
::
string
>
compile_params
;
std
::
vector
<
const
char
*>
param_cstrings
{};
std
::
vector
<
const
char
*>
param_cstrings
{};
hiprtcProgram
prog
;
hiprtcProgram
prog
;
std
::
string
cc
=
"gfx900"
;
// Default target architecture (can be changed as needed)
std
::
string
cc
=
"gfx900"
;
// Default target architecture (can be changed as needed)
int
major
,
minor
;
int
major
,
minor
;
hipError_t
e1
=
hipDeviceGetAttribute
(
&
major
,
hipDeviceAttributeComputeCapabilityMajor
,
0
);
hipError_t
e1
=
hipDeviceGetAttribute
(
hipError_t
e2
=
hipDeviceGetAttribute
(
&
minor
,
hipDeviceAttributeComputeCapabilityMinor
,
0
);
&
major
,
hipDeviceAttributeComputeCapabilityMajor
,
0
);
hipError_t
e2
=
hipDeviceGetAttribute
(
&
minor
,
hipDeviceAttributeComputeCapabilityMinor
,
0
);
if
(
e1
==
hipSuccess
&&
e2
==
hipSuccess
)
{
if
(
e1
==
hipSuccess
&&
e2
==
hipSuccess
)
{
cc
=
"gfx"
+
std
::
to_string
(
major
*
100
+
minor
*
10
);
cc
=
"gfx"
+
std
::
to_string
(
major
*
100
+
minor
*
10
);
...
@@ -86,10 +91,11 @@ static std::string HIPRTCCompile(const std::string& code, bool include_path = fa
...
@@ -86,10 +91,11 @@ static std::string HIPRTCCompile(const std::string& code, bool include_path = fa
compile_params
.
push_back
(
include_option
);
compile_params
.
push_back
(
include_option
);
}
}
for
(
const
auto
&
string
:
compile_params
)
{
for
(
const
auto
&
string
:
compile_params
)
{
param_cstrings
.
push_back
(
string
.
c_str
());
param_cstrings
.
push_back
(
string
.
c_str
());
}
}
HIPRTC_CALL
(
hiprtcCreateProgram
(
&
prog
,
code
.
c_str
(),
nullptr
,
0
,
nullptr
,
nullptr
));
HIPRTC_CALL
(
hiprtcCreateProgram
(
&
prog
,
code
.
c_str
(),
nullptr
,
0
,
nullptr
,
nullptr
));
hiprtcResult
compile_res
=
hiprtcResult
compile_res
=
hiprtcCompileProgram
(
prog
,
param_cstrings
.
size
(),
param_cstrings
.
data
());
hiprtcCompileProgram
(
prog
,
param_cstrings
.
size
(),
param_cstrings
.
data
());
...
@@ -110,11 +116,13 @@ static std::string HIPRTCCompile(const std::string& code, bool include_path = fa
...
@@ -110,11 +116,13 @@ static std::string HIPRTCCompile(const std::string& code, bool include_path = fa
return
code_out
;
return
code_out
;
}
}
static
std
::
unordered_map
<
std
::
string
,
runtime
::
FunctionInfo
>
ExtractFuncInfo
(
const
IRModule
&
mod
)
{
static
std
::
unordered_map
<
std
::
string
,
runtime
::
FunctionInfo
>
ExtractFuncInfo
(
const
IRModule
&
mod
)
{
std
::
unordered_map
<
std
::
string
,
runtime
::
FunctionInfo
>
fmap
;
std
::
unordered_map
<
std
::
string
,
runtime
::
FunctionInfo
>
fmap
;
for
(
auto
kv
:
mod
->
functions
)
{
for
(
auto
kv
:
mod
->
functions
)
{
ICHECK
(
kv
.
second
->
IsInstance
<
tir
::
PrimFuncNode
>
())
<<
"Can only lower IR Module with PrimFuncs"
;
ICHECK
(
kv
.
second
->
IsInstance
<
tir
::
PrimFuncNode
>
())
<<
"Can only lower IR Module with PrimFuncs"
;
auto
f
=
Downcast
<
tir
::
PrimFunc
>
(
kv
.
second
);
auto
f
=
Downcast
<
tir
::
PrimFunc
>
(
kv
.
second
);
runtime
::
FunctionInfo
info
;
runtime
::
FunctionInfo
info
;
...
@@ -129,7 +137,7 @@ static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
...
@@ -129,7 +137,7 @@ static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
info
.
arg_types
.
push_back
(
f
->
params
[
i
].
dtype
());
info
.
arg_types
.
push_back
(
f
->
params
[
i
].
dtype
());
}
}
if
(
auto
opt
=
f
->
GetAttr
<
Array
<
String
>>
(
tir
::
attr
::
kKernelLaunchParams
))
{
if
(
auto
opt
=
f
->
GetAttr
<
Array
<
String
>>
(
tir
::
attr
::
kKernelLaunchParams
))
{
for
(
const
auto
&
tag
:
opt
.
value
())
{
for
(
const
auto
&
tag
:
opt
.
value
())
{
info
.
launch_param_tags
.
push_back
(
tag
);
info
.
launch_param_tags
.
push_back
(
tag
);
}
}
}
}
...
@@ -146,7 +154,8 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
...
@@ -146,7 +154,8 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
cg
.
Init
(
output_ssa
);
cg
.
Init
(
output_ssa
);
for
(
auto
kv
:
mod
->
functions
)
{
for
(
auto
kv
:
mod
->
functions
)
{
ICHECK
(
kv
.
second
->
IsInstance
<
PrimFuncNode
>
())
<<
"CodeGenTileLangHIP: Can only take PrimFunc"
;
ICHECK
(
kv
.
second
->
IsInstance
<
PrimFuncNode
>
())
<<
"CodeGenTileLangHIP: Can only take PrimFunc"
;
auto
f
=
Downcast
<
PrimFunc
>
(
kv
.
second
);
auto
f
=
Downcast
<
PrimFunc
>
(
kv
.
second
);
auto
calling_conv
=
f
->
GetAttr
<
Integer
>
(
tvm
::
attr
::
kCallingConv
);
auto
calling_conv
=
f
->
GetAttr
<
Integer
>
(
tvm
::
attr
::
kCallingConv
);
ICHECK
(
calling_conv
==
CallingConv
::
kDeviceKernelLaunch
);
ICHECK
(
calling_conv
==
CallingConv
::
kDeviceKernelLaunch
);
...
@@ -154,21 +163,23 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
...
@@ -154,21 +163,23 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
}
}
std
::
string
code
=
cg
.
Finish
();
std
::
string
code
=
cg
.
Finish
();
if
(
const
auto
*
f
=
Registry
::
Get
(
"tvm_callback_hip_postproc"
))
{
if
(
const
auto
*
f
=
Registry
::
Get
(
"tvm_callback_hip_postproc"
))
{
code
=
(
*
f
)(
code
,
target
).
operator
std
::
string
();
code
=
(
*
f
)(
code
,
target
).
operator
std
::
string
();
}
}
std
::
string
fmt
=
"ptx"
;
std
::
string
fmt
=
"ptx"
;
std
::
string
ptx
;
std
::
string
ptx
;
if
(
const
auto
*
f
=
Registry
::
Get
(
"tvm_callback_hip_compile"
))
{
if
(
const
auto
*
f
=
Registry
::
Get
(
"tvm_callback_hip_compile"
))
{
ptx
=
(
*
f
)(
code
,
target
).
operator
std
::
string
();
ptx
=
(
*
f
)(
code
,
target
).
operator
std
::
string
();
if
(
ptx
[
0
]
!=
'/'
)
fmt
=
"hsaco"
;
if
(
ptx
[
0
]
!=
'/'
)
fmt
=
"hsaco"
;
}
else
{
}
else
{
ptx
=
HIPRTCCompile
(
code
,
false
);
ptx
=
HIPRTCCompile
(
code
,
false
);
}
}
return
ROCMModuleCreate
(
ptx
,
fmt
,
ExtractFuncInfo
(
mod
),
code
,
std
::
string
());
return
ROCMModuleCreate
(
ptx
,
fmt
,
ExtractFuncInfo
(
mod
),
code
,
std
::
string
());
}
}
TVM_REGISTER_GLOBAL
(
"target.build.tilelang_hip"
).
set_body_typed
(
BuildTileLangHIP
);
TVM_REGISTER_GLOBAL
(
"target.build.tilelang_hip"
)
.
set_body_typed
(
BuildTileLangHIP
);
}
// namespace codegen
}
// namespace codegen
}
// namespace tvm
}
// namespace tvm
src/target/utils.cc
View file @
549416f7
...
@@ -11,13 +11,17 @@
...
@@ -11,13 +11,17 @@
namespace
tvm
{
namespace
tvm
{
namespace
tl
{
namespace
tl
{
bool
TargetIsCuda
(
Target
target
)
{
return
target
->
GetTargetDeviceType
()
==
kDLCUDA
;
}
bool
TargetIsCuda
(
Target
target
)
{
bool
TargetIsRocm
(
Target
target
)
{
return
target
->
GetTargetDeviceType
()
==
kDLROCM
;
}
return
target
->
GetTargetDeviceType
()
==
kDLCUDA
;
}
bool
TargetIsRocm
(
Target
target
)
{
return
target
->
GetTargetDeviceType
()
==
kDLROCM
;
}
int
GetArchInt
(
Target
target
)
{
int
GetArchInt
(
Target
target
)
{
auto
s
=
target
->
GetAttr
<
String
>
(
"arch"
);
auto
s
=
target
->
GetAttr
<
String
>
(
"arch"
);
ICHECK
(
s
.
defined
());
ICHECK
(
s
.
defined
());
const
char
*
arch_str
=
s
.
value
().
c_str
();
const
char
*
arch_str
=
s
.
value
().
c_str
();
ICHECK_EQ
(
arch_str
[
0
],
's'
);
ICHECK_EQ
(
arch_str
[
0
],
's'
);
ICHECK_EQ
(
arch_str
[
1
],
'm'
);
ICHECK_EQ
(
arch_str
[
1
],
'm'
);
ICHECK_EQ
(
arch_str
[
2
],
'_'
);
ICHECK_EQ
(
arch_str
[
2
],
'_'
);
...
@@ -25,31 +29,36 @@ int GetArchInt(Target target) {
...
@@ -25,31 +29,36 @@ int GetArchInt(Target target) {
}
}
bool
TargetIsVolta
(
Target
target
)
{
bool
TargetIsVolta
(
Target
target
)
{
if
(
!
TargetIsCuda
(
target
))
return
false
;
if
(
!
TargetIsCuda
(
target
))
return
false
;
int
arch
=
GetArchInt
(
target
);
int
arch
=
GetArchInt
(
target
);
return
arch
>=
70
&&
arch
<
75
;
return
arch
>=
70
&&
arch
<
75
;
}
}
bool
TargetIsTuring
(
Target
target
)
{
bool
TargetIsTuring
(
Target
target
)
{
if
(
!
TargetIsCuda
(
target
))
return
false
;
if
(
!
TargetIsCuda
(
target
))
return
false
;
int
arch
=
GetArchInt
(
target
);
int
arch
=
GetArchInt
(
target
);
return
arch
>=
75
&&
arch
<
80
;
return
arch
>=
75
&&
arch
<
80
;
}
}
bool
TargetIsAmpere
(
Target
target
)
{
bool
TargetIsAmpere
(
Target
target
)
{
if
(
!
TargetIsCuda
(
target
))
return
false
;
if
(
!
TargetIsCuda
(
target
))
return
false
;
int
arch
=
GetArchInt
(
target
);
int
arch
=
GetArchInt
(
target
);
return
arch
>=
80
&&
arch
<
90
;
return
arch
>=
80
&&
arch
<
90
;
}
}
bool
TargetIsHopper
(
Target
target
)
{
bool
TargetIsHopper
(
Target
target
)
{
if
(
!
TargetIsCuda
(
target
))
return
false
;
if
(
!
TargetIsCuda
(
target
))
return
false
;
int
arch
=
GetArchInt
(
target
);
int
arch
=
GetArchInt
(
target
);
return
arch
>=
90
;
return
arch
>=
90
;
}
}
bool
TargetIsCDNA
(
Target
target
)
{
bool
TargetIsCDNA
(
Target
target
)
{
if
(
!
TargetIsRocm
(
target
))
return
false
;
if
(
!
TargetIsRocm
(
target
))
return
false
;
if
(
target
->
attrs
.
count
(
"mcpu"
))
{
if
(
target
->
attrs
.
count
(
"mcpu"
))
{
std
::
string
mcpu
=
Downcast
<
String
>
(
target
->
attrs
.
at
(
"mcpu"
));
std
::
string
mcpu
=
Downcast
<
String
>
(
target
->
attrs
.
at
(
"mcpu"
));
// if mcpu start with "gfx9", it is CDNA
// if mcpu start with "gfx9", it is CDNA
...
@@ -78,16 +87,18 @@ bool TargetHasAsyncCopy(Target target) {
...
@@ -78,16 +87,18 @@ bool TargetHasAsyncCopy(Target target) {
return
false
;
return
false
;
}
}
bool
TargetHasLdmatrix
(
Target
target
)
{
bool
TargetHasLdmatrix
(
Target
target
)
{
if
(
!
TargetIsCuda
(
target
))
return
false
;
if
(
!
TargetIsCuda
(
target
))
return
false
;
int
arch
=
GetArchInt
(
target
);
int
arch
=
GetArchInt
(
target
);
return
arch
>=
75
;
return
arch
>=
75
;
}
}
bool
TargetHasStmatrix
(
Target
target
)
{
bool
TargetHasStmatrix
(
Target
target
)
{
if
(
!
TargetIsCuda
(
target
))
return
false
;
if
(
!
TargetIsCuda
(
target
))
return
false
;
int
arch
=
GetArchInt
(
target
);
int
arch
=
GetArchInt
(
target
);
return
arch
>=
90
;
return
arch
>=
90
;
}
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/target/utils.h
View file @
549416f7
...
@@ -23,12 +23,12 @@ bool TargetIsTuring(Target target);
...
@@ -23,12 +23,12 @@ bool TargetIsTuring(Target target);
bool
TargetIsAmpere
(
Target
target
);
bool
TargetIsAmpere
(
Target
target
);
bool
TargetIsHopper
(
Target
target
);
bool
TargetIsHopper
(
Target
target
);
bool
TargetIsCDNA
(
Target
target
);
bool
TargetIsCDNA
(
Target
target
);
bool
TargetHasAsyncCopy
(
Target
target
);
bool
TargetHasAsyncCopy
(
Target
target
);
bool
TargetHasLdmatrix
(
Target
target
);
bool
TargetHasLdmatrix
(
Target
target
);
bool
TargetHasStmatrix
(
Target
target
);
bool
TargetHasStmatrix
(
Target
target
);
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
#endif
// TVM_TL_TARGET_UTILS_H_
#endif // TVM_TL_TARGET_UTILS_H_
src/tl_templates/cuda/common.h
View file @
549416f7
...
@@ -25,56 +25,57 @@ using cutlass::tfloat32_t;
...
@@ -25,56 +25,57 @@ using cutlass::tfloat32_t;
// Pack two half values.
// Pack two half values.
TL_DEVICE
unsigned
__pack_half2
(
const
half
x
,
const
half
y
)
{
TL_DEVICE
unsigned
__pack_half2
(
const
half
x
,
const
half
y
)
{
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
return
(
v1
<<
16
)
|
v0
;
return
(
v1
<<
16
)
|
v0
;
}
}
// Pack two half_t values.
// Pack two half_t values.
TL_DEVICE
unsigned
__pack_half2
(
const
half_t
x
,
const
half_t
y
)
{
TL_DEVICE
unsigned
__pack_half2
(
const
half_t
x
,
const
half_t
y
)
{
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
return
(
v1
<<
16
)
|
v0
;
return
(
v1
<<
16
)
|
v0
;
}
}
// Pack two bfloat16_t values.
// Pack two bfloat16_t values.
TL_DEVICE
unsigned
__pack_half2
(
const
bfloat16_t
x
,
const
bfloat16_t
y
)
{
TL_DEVICE
unsigned
__pack_half2
(
const
bfloat16_t
x
,
const
bfloat16_t
y
)
{
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
return
(
v1
<<
16
)
|
v0
;
return
(
v1
<<
16
)
|
v0
;
}
}
/// Helper to cast SMEM pointer to unsigned
/// Helper to cast SMEM pointer to unsigned
TL_DEVICE
uint32_t
smem_ptr_to_uint
(
void
const
*
const
ptr
)
{
TL_DEVICE
uint32_t
smem_ptr_to_uint
(
void
const
*
const
ptr
)
{
return
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
ptr
));
return
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
ptr
));
}
}
// AtomicAdd Functions for FP16
// AtomicAdd Functions for FP16
TL_DEVICE
void
atomicAdd
(
half_t
*
address
,
half_t
val
)
{
TL_DEVICE
void
atomicAdd
(
half_t
*
address
,
half_t
val
)
{
// Use atomicCAS with built-in cuda_fp16 support
// Use atomicCAS with built-in cuda_fp16 support
atomicAdd
(
reinterpret_cast
<
half
*>
(
address
),
static_cast
<
half
>
(
val
));
atomicAdd
(
reinterpret_cast
<
half
*>
(
address
),
static_cast
<
half
>
(
val
));
}
}
// AtomicAdd Functions for FP16
// AtomicAdd Functions for FP16
TL_DEVICE
void
atomicAdd
(
half_t
*
address
,
half_t
*
val
)
{
TL_DEVICE
void
atomicAdd
(
half_t
*
address
,
half_t
*
val
)
{
atomicAdd
(
reinterpret_cast
<
half
*>
(
address
),
static_cast
<
half
>
(
*
val
));
atomicAdd
(
reinterpret_cast
<
half
*>
(
address
),
static_cast
<
half
>
(
*
val
));
}
}
// AtomicAdd Functions for FP16
// AtomicAdd Functions for FP16
TL_DEVICE
void
atomicAddx2
(
half_t
*
address
,
half_t
*
val
)
{
TL_DEVICE
void
atomicAddx2
(
half_t
*
address
,
half_t
*
val
)
{
atomicAdd
(
reinterpret_cast
<
half2
*>
(
address
),
static_cast
<
half2
>
(
*
reinterpret_cast
<
half2
*>
(
val
)));
atomicAdd
(
reinterpret_cast
<
half2
*>
(
address
),
static_cast
<
half2
>
(
*
reinterpret_cast
<
half2
*>
(
val
)));
}
}
TL_DEVICE
void
atomicAdd
(
half_t
*
address
,
float
val
)
{
TL_DEVICE
void
atomicAdd
(
half_t
*
address
,
float
val
)
{
// Use atomicCAS with built-in cuda_fp16 support
// Use atomicCAS with built-in cuda_fp16 support
atomicAdd
(
reinterpret_cast
<
half
*>
(
address
),
__float2half
(
val
));
atomicAdd
(
reinterpret_cast
<
half
*>
(
address
),
__float2half
(
val
));
}
}
// DP4A
// DP4A
template
<
typename
InDatatype
,
typename
OutDatatype
>
template
<
typename
InDatatype
,
typename
OutDatatype
>
TL_DEVICE
void
DP4A
(
InDatatype
*
a
,
InDatatype
*
b
,
OutDatatype
*
c
)
{
TL_DEVICE
void
DP4A
(
InDatatype
*
a
,
InDatatype
*
b
,
OutDatatype
*
c
)
{
const
int
a_int
=
*
((
int
*
)
a
);
const
int
a_int
=
*
((
int
*
)
a
);
const
int
b_int
=
*
((
int
*
)
b
);
const
int
b_int
=
*
((
int
*
)
b
);
const
int
c_int
=
*
((
int
*
)
c
);
const
int
c_int
=
*
((
int
*
)
c
);
*
c
=
__dp4a
(
a_int
,
b_int
,
c_int
);
*
c
=
__dp4a
(
a_int
,
b_int
,
c_int
);
}
}
src/tl_templates/cuda/copy.h
View file @
549416f7
...
@@ -10,10 +10,11 @@
...
@@ -10,10 +10,11 @@
namespace
tl
{
namespace
tl
{
TL_DEVICE
void
cp_async_commit
()
{
asm
volatile
(
"cp.async.commit_group;
\n
"
::
);
}
TL_DEVICE
void
cp_async_commit
()
{
asm
volatile
(
"cp.async.commit_group;
\n
"
::
);
}
template
<
int
N
>
template
<
int
N
>
TL_DEVICE
void
cp_async_wait
()
{
TL_DEVICE
void
cp_async_wait
()
{
if
constexpr
(
N
==
0
)
{
if
constexpr
(
N
==
0
)
{
asm
volatile
(
"cp.async.wait_all;
\n
"
::
);
asm
volatile
(
"cp.async.wait_all;
\n
"
::
);
}
else
{
}
else
{
...
@@ -22,7 +23,7 @@ TL_DEVICE void cp_async_wait() {
...
@@ -22,7 +23,7 @@ TL_DEVICE void cp_async_wait() {
}
}
template
<
int
N
>
template
<
int
N
>
TL_DEVICE
void
cp_async_gs
(
void
const
*
const
smem_addr
,
void
*
global_ptr
)
{
TL_DEVICE
void
cp_async_gs
(
void
const
*
const
smem_addr
,
void
*
global_ptr
)
{
static_assert
(
N
==
16
||
N
==
8
||
N
==
4
);
static_assert
(
N
==
16
||
N
==
8
||
N
==
4
);
unsigned
int
addr
=
smem_ptr_to_uint
(
smem_addr
);
unsigned
int
addr
=
smem_ptr_to_uint
(
smem_addr
);
if
constexpr
(
N
==
16
)
{
if
constexpr
(
N
==
16
)
{
...
@@ -33,7 +34,7 @@ TL_DEVICE void cp_async_gs(void const* const smem_addr, void* global_ptr) {
...
@@ -33,7 +34,7 @@ TL_DEVICE void cp_async_gs(void const* const smem_addr, void* global_ptr) {
"cp.async.cg.shared.global [%0], [%1], %2;"
"cp.async.cg.shared.global [%0], [%1], %2;"
#endif
#endif
::
"r"
(
addr
),
::
"r"
(
addr
),
"l"
((
void
*
)(
global_ptr
)),
"n"
(
N
));
"l"
((
void
*
)(
global_ptr
)),
"n"
(
N
));
}
else
{
}
else
{
__asm__
__volatile__
(
__asm__
__volatile__
(
#if TL_ENABLE_L2_PREFETCH
#if TL_ENABLE_L2_PREFETCH
...
@@ -42,12 +43,13 @@ TL_DEVICE void cp_async_gs(void const* const smem_addr, void* global_ptr) {
...
@@ -42,12 +43,13 @@ TL_DEVICE void cp_async_gs(void const* const smem_addr, void* global_ptr) {
"cp.async.ca.shared.global [%0], [%1], %2;"
"cp.async.ca.shared.global [%0], [%1], %2;"
#endif
#endif
::
"r"
(
addr
),
::
"r"
(
addr
),
"l"
((
void
*
)(
global_ptr
)),
"n"
(
N
));
"l"
((
void
*
)(
global_ptr
)),
"n"
(
N
));
}
}
}
}
template
<
int
N
>
template
<
int
N
>
TL_DEVICE
void
cp_async_gs_conditional
(
void
const
*
const
smem_addr
,
void
*
global_ptr
,
bool
cond
)
{
TL_DEVICE
void
cp_async_gs_conditional
(
void
const
*
const
smem_addr
,
void
*
global_ptr
,
bool
cond
)
{
static_assert
(
N
==
16
||
N
==
8
||
N
==
4
);
static_assert
(
N
==
16
||
N
==
8
||
N
==
4
);
int
bytes
=
cond
?
N
:
0
;
int
bytes
=
cond
?
N
:
0
;
unsigned
int
addr
=
smem_ptr_to_uint
(
smem_addr
);
unsigned
int
addr
=
smem_ptr_to_uint
(
smem_addr
);
...
@@ -59,7 +61,7 @@ TL_DEVICE void cp_async_gs_conditional(void const* const smem_addr, void* global
...
@@ -59,7 +61,7 @@ TL_DEVICE void cp_async_gs_conditional(void const* const smem_addr, void* global
"cp.async.cg.shared.global [%0], [%1], %2, %3;"
"cp.async.cg.shared.global [%0], [%1], %2, %3;"
#endif
#endif
::
"r"
(
addr
),
::
"r"
(
addr
),
"l"
((
void
*
)(
global_ptr
)),
"n"
(
N
),
"r"
(
bytes
));
"l"
((
void
*
)(
global_ptr
)),
"n"
(
N
),
"r"
(
bytes
));
}
else
{
}
else
{
__asm__
__volatile__
(
__asm__
__volatile__
(
#if TL_ENABLE_L2_PREFETCH
#if TL_ENABLE_L2_PREFETCH
...
@@ -68,8 +70,8 @@ TL_DEVICE void cp_async_gs_conditional(void const* const smem_addr, void* global
...
@@ -68,8 +70,8 @@ TL_DEVICE void cp_async_gs_conditional(void const* const smem_addr, void* global
"cp.async.ca.shared.global [%0], [%1], %2, %3;"
"cp.async.ca.shared.global [%0], [%1], %2, %3;"
#endif
#endif
::
"r"
(
addr
),
::
"r"
(
addr
),
"l"
((
void
*
)(
global_ptr
)),
"n"
(
N
),
"r"
(
bytes
));
"l"
((
void
*
)(
global_ptr
)),
"n"
(
N
),
"r"
(
bytes
));
}
}
}
}
}
// namespace tl
}
// namespace tl
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment