OpenDAS / tilelang · Commits · 549416f7

Commit 549416f7, authored Jan 11, 2025 by LeiWang1999

    Merge branch 'main' of https://github.com/microsoft/TileLang into main

Parents: 4d63633a, 7fad4e88

Changes: 90 files in the commit; this view shows 20 changed files with 11122 additions and 8555 deletions (+11122 / -8555).
Changed files shown below:

src/op/gemm.h                   +9     -9
src/op/op.cc                    +18    -10
src/op/op.h                     +27    -24
src/op/parallel.cc              +66    -45
src/op/parallel.h               +14    -14
src/op/reduce.cc                +84    -70
src/op/reduce.h                 +9     -9
src/runtime/runtime.cc          +68    -48
src/runtime/runtime.h           +7     -5
src/target/codegen_cuda.cc      +472   -379
src/target/codegen_cuda.h       +50    -40
src/target/codegen_hip.cc       +377   -307
src/target/codegen_hip.h        +42    -34
src/target/cuda.h               +9756  -7470
src/target/rt_mod_cuda.cc       +21    -14
src/target/rt_mod_hip.cc        +42    -31
src/target/utils.cc             +23    -12
src/target/utils.h              +4     -4
src/tl_templates/cuda/common.h  +21    -20
src/tl_templates/cuda/copy.h    +12    -10
src/op/gemm.h

@@ -18,18 +18,18 @@ namespace tl {

using namespace tir;

class Gemm : public Operator {
public:
  Gemm(Array<PrimExpr> args, BufferMap vmap);
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
  static const Op &Get();

  enum class GemmWarpPolicy {
    kSquare = 0,
    kFullRow = 1,
    kFullCol = 2,
  } policy;

private:
  std::pair<int, int> ComputeWarpPartition(int num_warps, Target target) const;

  Array<PrimExpr> call_args;

@@ -38,11 +38,11 @@ class Gemm : public Operator {

  int M, N, K;
  // k_pack: refer to bitblas/tl/mfma_macro_generator.py::k_pack;
  // only enabled under CDNA MFMA instructions.
  int kPack = 1;
  bool completed_ = false;
};

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_GEMM_H_
src/op/op.cc

@@ -20,13 +20,14 @@ using namespace tir;

TIR_REGISTER_TL_OP(RegionOp, region)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap) {
  auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
  Op op = call->op.as<Op>().value();
  if (op_map.count(op)) {
    Operator *ptr = static_cast<Operator *>(op_map[op](call->args, vmap));
    ICHECK(ptr != nullptr);
    return std::unique_ptr<Operator>(ptr);
  }

@@ -41,7 +42,7 @@ std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap) {

  return nullptr;
}

Var GetVarFromAccessPtr(const PrimExpr &expr) {
  auto call = expr.as<CallNode>();
  ICHECK(call);
  ICHECK(call->op.same_as(builtin::tvm_access_ptr()));

@@ -67,20 +68,27 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {

bool RegionOp::IsFullRegion() const {
  for (size_t i = 0; i < ranges_.size(); i++) {
    if (!is_zero(ranges_[i]->min))
      return false;
    if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i]))
      return false;
  }
  return true;
}

Stmt Operator::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  ICHECK(0) << "Not Implemented Lower method.";
  return Evaluate(0);
}

Stmt Operator::Canonialize(const CanonializeArgs &T,
                           arith::Analyzer *analyzer) const {
  return {};
}

LayoutMap Operator::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  return {};
}

} // namespace tl
} // namespace tvm
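A minimal sketch of how the TLOpBuilder registry above is typically consumed; the wrapper function and its variable names are assumptions rather than code from this commit, but ParseOperator, Operator::Lower, LowerArgs and BufferMap are the entities defined in this file and in op.h:

#include <tvm/arith/analyzer.h>

// Sketch only: resolve a tl.* call into its Operator and lower it.
// `call`, `vmap` and `lower_args` are assumed to be supplied by the caller.
static Stmt LowerTileCall(const Call &call, const BufferMap &vmap,
                          const LowerArgs &lower_args) {
  std::unique_ptr<Operator> op = ParseOperator(call, vmap);
  ICHECK(op != nullptr) << "not a registered TL operator";
  arith::Analyzer analyzer;
  // Falls back to the ICHECK(0) stub above if the op does not override Lower.
  return op->Lower(lower_args, &analyzer);
}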
src/op/op.h

@@ -25,17 +25,19 @@ using namespace tir;

using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>;
using LayoutMap = Map<Buffer, Layout>;
using BufferMap = Map<Var, Buffer>;
using OpBuilderFunc = TypedPackedFunc<void *(Array<PrimExpr>, BufferMap)>;

#define TIR_REGISTER_TL_OP(Entry, OpName)                                    \
  const Op &Entry::Get() {                                                   \
    static const Op &op = Op::Get("tl." #OpName);                           \
    return op;                                                               \
  }                                                                          \
  TVM_REGISTER_OP("tl." #OpName)                                             \
      .set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName)           \
      .set_attr<OpBuilderFunc>("TLOpBuilder",                                \
                               [](Array<PrimExpr> a, BufferMap b) {          \
                                 return (void *)(new Entry(a, b));           \
                               })

enum class InferLevel {
  kFree = 0,

@@ -64,35 +66,36 @@ struct CanonializeArgs {

};

class Operator {
public:
  virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
  virtual Stmt Canonialize(const CanonializeArgs &T,
                           arith::Analyzer *analyzer) const;
  virtual LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level);
  virtual ~Operator() = default;
};

class RegionOp : public Operator {
public:
  RegionOp(Array<PrimExpr> args, BufferMap vmap);
  static const Op &Get();
  const Buffer &GetBuffer() const { return buffer_; }
  const Array<Range> &GetRanges() const { return ranges_; }
  int GetAccessMask() const { return access_mask_; }
  bool IsFullRegion() const;

private:
  Buffer buffer_;
  Array<Range> ranges_;
  int access_mask_;
};

Var GetVarFromAccessPtr(const PrimExpr &expr);

std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap);
std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap);

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_OP_H_
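For orientation, this is the registration pattern the TIR_REGISTER_TL_OP macro above produces; the FillOp class and the "fill" op name below are hypothetical and only mirror how RegionOp and ReduceOp are registered elsewhere in this commit:

// Hypothetical operator, shown only to illustrate TIR_REGISTER_TL_OP usage.
class FillOp : public Operator {
public:
  FillOp(Array<PrimExpr> args, BufferMap vmap);
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
  static const Op &Get();
};

// Expands to FillOp::Get() plus a TVM_REGISTER_OP("tl.fill") entry whose
// "TLOpBuilder" attribute constructs a new FillOp from the call arguments.
TIR_REGISTER_TL_OP(FillOp, fill)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));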
src/op/parallel.cc

@@ -39,21 +39,22 @@ using namespace tir;

namespace attr {
/*! \brief Mark how the loop is vectorized. */
constexpr const char *coalesced_width = "coalesced_width";
} // namespace attr

class IfBufferRemapLoopGenerator : public StmtExprMutator {
public:
  static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
                 Map<Buffer, Layout> layout_map) {
    IfBufferRemapLoopGenerator generator(buffer_remap, layout_map);
    return Downcast<For>(generator(std::move(stmt)));
  }

private:
  IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap,
                             Map<Buffer, Layout> layout_map)
      : buffer_remap_(buffer_remap), layout_map_(layout_map) {}

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    if (buffer_remap_.count(load->buffer)) {

@@ -65,7 +66,7 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {

    return load;
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    if (buffer_remap_.count(store->buffer)) {
      auto new_indices = layout_map_[store->buffer]->Forward(store->indices);

@@ -79,18 +80,20 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {

  Map<Buffer, Layout> layout_map_;
};

void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
  ICHECK(op->kind == ForKind::kParallel);
  p->loop_vars_.push_back(
      IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
  p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
  StmtExprVisitor::VisitStmt_(op);
}

void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }

@@ -99,11 +102,12 @@ void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode* op) {

  StmtExprVisitor::VisitStmt_(op);
}

void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }

@@ -113,18 +117,20 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode* op) {

ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); }

bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
  auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
  return StructuralEqual()(indice_map_[buffer], common_indice);
}

LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (loop_layout_.defined())
    return {};
  if (level == InferLevel::kStrict)
    return {};

  // Step 1: try to infer loop's partition from a source fragment
  Buffer source_buffer, read_source_buffer;
  for (const auto &[buffer, _] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto frag = T.layout_map[buffer].as<Fragment>().value();
      if (buffer_is_write_.count(buffer))

@@ -133,14 +139,16 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {

        read_source_buffer = buffer;
    }
  }
  auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
    Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
    if (IsCommonAccessIndice(buffer)) {
      return src_layout;
    } else {
      Var rep;
      auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
                              IterVarType::kDataPar);
      PrimExpr loop_var_to_thread =
          src_layout->ForwardThread(indice_map_[buffer], rep);
      return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter);
    }
  };

@@ -150,12 +158,14 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {

  if (read_source_buffer.defined()) {
    loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
    // The loop doesn't need to be replicated.
    if (!is_one(loop_layout_->ReplicateExtent()))
      loop_layout_ = loop_layout_->DeReplicate();
    // If replication still remains, add a condition.
    if (!is_one(loop_layout_->ReplicateExtent())) {
      auto inv = loop_layout_->Inverse();
      Array<PrimExpr> fwd;
      for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
        fwd.push_back(0);
      fwd.push_back(InputPlaceholder(0));
      auto rep = inv->Forward(fwd).back();
      AddPredicate(EQ(rep, 0));

@@ -163,17 +173,19 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {

  } else {
    // The vectorize size must be aware of the buffer_remap,
    // as the pass will post-process the layout.
    auto maybe_remapped_root_ =
        IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
    int vector_size = GetVectorizeSize(maybe_remapped_root_);

    // Check if coalesced_width is defined
    if (auto coalesced_width =
            root_->annotations.Get(tl::attr::coalesced_width)) {
      if (const auto *imm = coalesced_width.as<IntImmNode>()) {
        int expected = imm->value;
        // Verify that vector_size is divisible by expected
        if (vector_size % expected != 0) {
          LOG(FATAL) << "Vector size " << vector_size
                     << " is not divisible by coalesced width " << expected;
        }
        vector_size = expected;
      } else {

@@ -184,31 +196,37 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {

      loop_layout_ = PlanLoopPartition(root_, T.block_size, vector_size);
    }
    PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
    if (!analyzer_.CanProveEqual(loop_thread_extent,
                                 static_cast<int>(T.block_size)))
      AddPredicate(LT(InputPlaceholder(0), loop_thread_extent));
  } else {
    return {};
  }

  // Step 2: Check that the loop's partition can correctly align with all
  // source fragments
  for (const auto &[buffer, _] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto fragment = T.layout_map[buffer].as<Fragment>().value();
      // TODO: Add thread checks for replicated cases
      // need to wildcard match the rhs with lhs
      if (!is_one(loop_layout_->ReplicateExtent()) ||
          !is_one(fragment->ReplicateExtent()))
        continue;
      auto vars =
          loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
      auto lhs = loop_layout_->ForwardThread(vars, NullOpt);
      auto rhs = fragment->ForwardThread(indice_map_[buffer], NullOpt);
      auto diff = analyzer_.Simplify(lhs - rhs);
      ICHECK(is_zero(diff)) << "Layout infer conflict for " << buffer << " "
                            << source_buffer << "\nLHS = " << lhs
                            << "\nRHS = " << rhs;
    }
  }
  // Step 3: Infer other fragment's layout from the loop's partition
  LayoutMap results;
  for (const auto &[buffer, _] : indice_map_) {
    if (!T.layout_map.count(buffer))
      results.Set(buffer, CompleteBufferFragment(buffer));
  }
  return results;
}

@@ -221,18 +239,20 @@ Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {

  }
}

Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
  ICHECK(loop_layout_.defined());
  if (IsCommonAccessIndice(buffer))
    return loop_layout_;
  PrimExpr rep_b = MakeFlattenedExpression(
      DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
  auto bijective_indice = indice_map_[buffer];
  bijective_indice.push_back(rep_b);
  Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
  PrimExpr indice_rep_extent =
      ind_inv->InputShape().back(); // this is the size of rep_b
  PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
  PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;

@@ -242,11 +262,12 @@ Fragment ParallelOp::CompleteBufferFragment(const Buffer& buffer) {

  }
  fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
  PrimExpr thd_b = loop_layout_->ForwardThread(
      ind_inv->Forward(fwd),
      FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
  return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt)
      ->CondenseReplicateVar();
}

} // namespace tl
} // namespace tvm
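A small sketch of how the coalesced_width annotation consumed by InferLayout above could be attached to a parallel loop; the helper and its parameter names are placeholders and the construction is an assumption, not code from this commit:

// Sketch only: annotate a parallel loop so that layout inference clamps the
// vectorize size to 4 (the check above fatals if the automatically inferred
// vector size is not divisible by the annotated width).
static For AnnotateCoalescedWidth(const Var &loop_var, const PrimExpr &extent,
                                  const Stmt &body) {
  Map<String, ObjectRef> annotations;
  annotations.Set(tl::attr::coalesced_width, Integer(4));
  return For(loop_var, /*min=*/0, extent, ForKind::kParallel, body,
             /*thread_binding=*/NullOpt, annotations);
}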
src/op/parallel.h

@@ -23,30 +23,30 @@ using namespace tir;

class ParallelOp;

class ParallelLoopNestVisitor : public StmtExprVisitor {
private:
  ParallelLoopNestVisitor(ParallelOp *op) : p(op){};
  void VisitStmt_(const ForNode *op) final;
  void VisitStmt_(const BufferStoreNode *op) final;
  void VisitExpr_(const BufferLoadNode *op) final;

  ParallelOp *p;

  friend class ParallelOp;
};

class ParallelOp : public Operator {
public:
  ParallelOp(For root);
  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
  Fragment GetLoopLayout() const { return loop_layout_; }
  For GetRoot() const { return root_; }
  Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
  Optional<PrimExpr> GetPredicate(Var thread_var) const;

private:
  Fragment CompleteBufferFragment(const Buffer &buffer);
  bool IsCommonAccessIndice(const Buffer &buffer) const;
  void AddPredicate(PrimExpr expr) {
    predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
  }

@@ -66,7 +66,7 @@ class ParallelOp : public Operator {

  friend class ParallelLoopNestVisitor;
};

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_PARALLEL_H_
src/op/reduce.cc

@@ -41,57 +41,58 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {

PrimExpr ReduceOp::MakeInitValue() const {
  switch (type) {
  case ReduceType::kSum:
    return make_zero(dst->dtype);
  case ReduceType::kAbsSum:
    return make_zero(dst->dtype);
  case ReduceType::kMax:
    return make_const(dst->dtype, -INFINITY);
  case ReduceType::kMin:
    return make_const(dst->dtype, INFINITY);
  default:
    ICHECK(0);
  }
}

PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
  PrimExpr lhs = a, rhs = b;
  if (lhs->dtype != rhs->dtype) {
    rhs = Cast(lhs->dtype, rhs);
  }
  switch (type) {
  case ReduceType::kSum:
    return lhs + rhs;
  case ReduceType::kAbsSum:
    return lhs + Max(rhs, -rhs);
  case ReduceType::kMax:
    return Max(lhs, rhs);
  case ReduceType::kMin:
    return Min(lhs, rhs);
  default:
    ICHECK(0);
    return PrimExpr(0);
  }
}

std::string ReduceOp::MakeCodegenReducer() const {
  switch (type) {
  case ReduceType::kSum:
    return "tl::SumOp";
  case ReduceType::kAbsSum:
    return "tl::SumOp";
  case ReduceType::kMax:
    return "tl::MaxOp";
  case ReduceType::kMin:
    return "tl::MinOp";
  default:
    ICHECK(0);
    return "";
  }
}

Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  ICHECK(this->src.scope() == "local.fragment" &&
         this->dst.scope() == "local.fragment")
      << "Reduce for shared memory not implemented.";
  auto src_buffer = T.buffer_remap[this->src];
  auto dst_buffer = T.buffer_remap[this->dst];

@@ -101,20 +102,24 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {

  Array<IterVar> dst_vars;
  for (size_t i = 0; i < dst_layout->InputDim(); i++) {
    Var var = Var(std::string{char('i' + i)});
    dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
                               IterVarType::kDataPar));
  }
  Array<IterVar> src_vars = dst_vars;
  src_vars.insert(src_vars.begin() + this->dim,
                  {Range(0, src_layout->InputShape()[this->dim]), Var("rv"),
                   IterVarType::kDataPar});
  Array<PrimExpr> src_indices = src_layout->Forward(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
  Array<PrimExpr> dst_indices = dst_layout->Forward(
      dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));

  Array<Stmt> stmts;

  // make reduce-init stmt
  if (this->clear)
    stmts.push_back(BufferStore(dst_buffer, this->MakeInitValue(), dst_indices));

  // make thread-local reduce
  Array<PrimExpr> src_indice_compressed;

@@ -122,45 +127,50 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {

  for (size_t i = 0; i < src_layout->OutputDim(); i++) {
    PrimExpr expr;
    IterVar var;
    std::tie(expr, var) = CompressIterator(src_indices[i], src_vars,
                                           src_vars[this->dim]->var, analyzer);
    src_indice_compressed.push_back(expr);
    src_var_compressed.push_back(var);
  }
  Stmt reduce_local = BufferStore(
      dst_buffer,
      this->MakeReduce(BufferLoad(dst_buffer, dst_indices),
                       BufferLoad(src_buffer, src_indice_compressed)),
      dst_indices);
  for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
    reduce_local =
        For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
            ForKind::kUnrolled, reduce_local, NullOpt,
            {{tir::attr::pragma_unroll_explicit, Bool(false)}});
  }
  stmts.push_back(reduce_local);

  // make inter-thread reduce
  PrimExpr src_thread = src_layout->ForwardThread(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
  auto iter_sum =
      arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
  for (const auto &iter_split : iter_sum->args) {
    auto mark = iter_split->source->source.as<Var>();
    ICHECK(mark.defined());
    if (mark.value().same_as(src_vars[this->dim]->var)) {
      auto scale = as_const_int(iter_split->scale);
      auto extent = as_const_int(iter_split->extent);
      ICHECK(scale != nullptr && extent != nullptr);
      if (*extent == 1)
        continue;
      int reducing_threads = (*extent) * (*scale);
      std::stringstream ss;
      ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
         << reducing_threads << ", " << (*scale) << ">::run";
      Array<PrimExpr> thread_reduce_args = {StringImm(ss.str()),
                                            BufferLoad(dst_buffer, dst_indices)};
      if (reducing_threads >= 32) {
        PrimExpr workspace = T.AddWorkspace(T.block_size, dst_buffer->dtype);
        thread_reduce_args.push_back(workspace);
      }
      auto call =
          Call(dst_buffer->dtype, builtin::call_extern(), thread_reduce_args);
      stmts.push_back(BufferStore(dst_buffer, call, dst_indices));
    }
  }

@@ -170,15 +180,17 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {

  // make the outer spatial loop
  Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
  for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
    body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
               ForKind::kParallel, body);
  }
  body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
  return body;
}

LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (level >= InferLevel::kStrict)
    return {};
  if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
      T.layout_map.count(src) && !T.layout_map.count(dst)) {
    auto src_layout = T.layout_map[src].as<Fragment>().value();

@@ -197,10 +209,11 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {

        fwd.push_back(InputPlaceholder(i - 1));
      }
    }
    auto thd = src_layout->ForwardThread(
        fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
    Fragment dst_layout =
        Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)
            ->CondenseReplicateVar();
    return {{dst, dst_layout}};
  }
  return {};

@@ -208,7 +221,8 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {

TIR_REGISTER_TL_OP(ReduceOp, reduce)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm
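As a side note, the inter-thread reduction in ReduceOp::Lower above is emitted as a call_extern whose first argument names a templated device-side reducer (tl::AllReduce). A minimal sketch of that construction in isolation, assuming the buffer, indices and workspace come from a surrounding lowering context like the one above (the helper name is illustrative):

#include <sstream>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>

using namespace tvm;
using namespace tvm::tir;

// Sketch only: build `tl::AllReduce<tl::SumOp, N, S>::run(dst, [workspace])`
// as a call_extern expression, mirroring the construction in ReduceOp::Lower.
static PrimExpr MakeAllReduceCall(const Buffer &dst_buffer,
                                  const Array<PrimExpr> &dst_indices,
                                  const PrimExpr &workspace,
                                  int reducing_threads, int scale) {
  std::stringstream ss;
  ss << "tl::AllReduce<tl::SumOp, " << reducing_threads << ", " << scale
     << ">::run";
  Array<PrimExpr> args = {StringImm(ss.str()),
                          BufferLoad(dst_buffer, dst_indices)};
  if (reducing_threads >= 32) {
    args.push_back(workspace); // shared-memory scratch for cross-warp exchange
  }
  return Call(dst_buffer->dtype, builtin::call_extern(), args);
}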
src/op/reduce.h

@@ -18,13 +18,13 @@ namespace tl {

using namespace tir;

class ReduceOp : public Operator {
public:
  ReduceOp(Array<PrimExpr> args, BufferMap vmap);
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
  static const Op &Get();

private:
  tir::Buffer src, dst;
  int dim;
  enum class ReduceType {

@@ -36,11 +36,11 @@ class ReduceOp : public Operator {

  bool clear;

  PrimExpr MakeInitValue() const;
  PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const;
  std::string MakeCodegenReducer() const;
};

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_REDUCE_H_
src/runtime/runtime.cc

@@ -17,12 +17,12 @@ namespace tl {

using namespace runtime;

template <typename T>
static std::string ArrayToStr(const T *ptr, size_t n) {
  std::stringstream ss;
  ss << "[";
  for (size_t i = 0; i < n; i++) {
    if (i > 0)
      ss << ", ";
    ss << ptr[i];
  }
  ss << "]";

@@ -30,10 +30,10 @@ static std::string ArrayToStr(const T* ptr, size_t n) {

}

struct TensorMapArgs {
  CUtensorMap *map;
  CUtensorMapDataType type;
  cuuint32_t tensorRank;
  void *globalAddress;
  cuuint64_t globalDim[5], globalStride[5];
  cuuint32_t boxDim[5], elementStrides[5];
  CUtensorMapInterleave interleave;

@@ -45,8 +45,9 @@ struct TensorMapArgs {

    TensorMapArgs T;
    int idx = 0;
    ICHECK(args.num_args >= 8);
    T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++]));
    T.type =
        static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
    T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
    T.globalAddress = args[idx++];
    ICHECK(T.tensorRank >= 1 && T.tensorRank <= 5);

@@ -63,10 +64,14 @@ struct TensorMapArgs {

    for (size_t i = 0; i < T.tensorRank; i++) {
      T.elementStrides[i] = static_cast<cuuint64_t>(args[idx++]);
    }
    T.interleave =
        static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
    T.swizzle =
        static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
    T.l2Promotion =
        static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
    T.oobFill =
        static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
    return T;
  }

@@ -79,7 +84,8 @@ struct TensorMapArgs {

       << "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl
       << "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl
       << "boxDim " << ArrayToStr(boxDim, tensorRank) << std::endl
       << "elementStrides " << ArrayToStr(elementStrides, tensorRank)
       << std::endl
       << "interleave " << interleave << std::endl
       << "swizzle " << swizzle << std::endl
       << "l2Promotion " << l2Promotion << std::endl

@@ -89,23 +95,26 @@ struct TensorMapArgs {

};

// set device api
TVM_REGISTER_GLOBAL(tvm_tensormap_create_tiled)
    .set_body([](TVMArgs args, TVMRetValue *ret) {
      TensorMapArgs T = TensorMapArgs::Extract(args);
      CUresult result = cuTensorMapEncodeTiled(
          T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
          T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave,
          T.swizzle, T.l2Promotion, T.oobFill);
      if (result != CUDA_SUCCESS) {
        LOG_FATAL << "Failed to initialize the TMA descriptor " << result
                  << std::endl
                  << T.ToDebugString();
      }
      *ret = static_cast<int>(result);
    });

struct TensorMapIm2ColArgs {
  CUtensorMap *map;
  CUtensorMapDataType type;
  cuuint32_t tensorRank;
  void *globalAddress;
  cuuint64_t globalDim[5], globalStride[5];
  cuuint32_t elementStrides[5];
  int pixelBoxLowerCorner[3], pixelBoxUpperCorner[3];

@@ -119,8 +128,9 @@ struct TensorMapIm2ColArgs {

    TensorMapIm2ColArgs T;
    int idx = 0;
    ICHECK(args.num_args >= 8);
    T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++]));
    T.type =
        static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
    T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
    T.globalAddress = args[idx++];
    ICHECK(T.tensorRank >= 3 && T.tensorRank <= 5);

@@ -142,10 +152,14 @@ struct TensorMapIm2ColArgs {

    }
    T.smem_box_pixel = static_cast<cuuint64_t>(args[idx++]);
    T.smem_box_channel = static_cast<cuuint64_t>(args[idx++]);
    T.interleave =
        static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
    T.swizzle =
        static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
    T.l2Promotion =
        static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
    T.oobFill =
        static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
    return T;
  }

@@ -159,9 +173,12 @@ struct TensorMapIm2ColArgs {

       << "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl
       << "smem_box_pixel " << smem_box_pixel << std::endl
       << "smem_box_channel " << smem_box_channel << std::endl
       << "pixelBoxLowerCorner " << ArrayToStr(pixelBoxLowerCorner, tensorRank - 2)
       << std::endl
       << "pixelBoxUpperCorner " << ArrayToStr(pixelBoxUpperCorner, tensorRank - 2)
       << std::endl
       << "elementStrides " << ArrayToStr(elementStrides, tensorRank)
       << std::endl
       << "interleave " << interleave << std::endl
       << "swizzle " << swizzle << std::endl
       << "l2Promotion " << l2Promotion << std::endl

@@ -170,18 +187,21 @@ struct TensorMapIm2ColArgs {

  }
};

TVM_REGISTER_GLOBAL(tvm_tensormap_create_im2col)
    .set_body([](TVMArgs args, TVMRetValue *ret) {
      TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args);
      CUresult result = cuTensorMapEncodeIm2col(
          T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
          T.globalStride + 1, T.pixelBoxLowerCorner, T.pixelBoxUpperCorner,
          T.smem_box_channel, T.smem_box_pixel, T.elementStrides, T.interleave,
          T.swizzle, T.l2Promotion, T.oobFill);
      if (result != CUDA_SUCCESS) {
        LOG_FATAL << "Failed to initialize the TMA descriptor " << result
                  << std::endl
                  << T.ToDebugString();
      }
      *ret = static_cast<int>(result);
    });

} // namespace tl
} // namespace tvm
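A sketch of the host-side call site, which is not part of this diff: the two packed functions registered above are looked up by the names declared in runtime.h and invoked with the flat argument order that Extract() unpacks (descriptor pointer, dtype, rank, base address, the per-dimension arrays, then the interleave/swizzle/L2-promotion/OOB-fill flags). The helper below is an assumption; the invocation itself is rank-dependent, so it is left as a comment.

#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>

// Sketch only; assumes the name constants from src/runtime/runtime.h.
static const tvm::runtime::PackedFunc *GetTensorMapCreateTiled() {
  const tvm::runtime::PackedFunc *f =
      tvm::runtime::Registry::Get(tvm::tl::tvm_tensormap_create_tiled);
  ICHECK(f != nullptr) << "__tvm_tensormap_create_tiled is not registered";
  // Invocation shape (rank-dependent):
  // (*f)(&tensor_map, dtype, rank, global_address,
  //      dim..., stride..., box_dim..., elem_stride...,
  //      interleave, swizzle, l2_promotion, oob_fill);
  return f;
}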
src/runtime/runtime.h

@@ -13,9 +13,11 @@

namespace tvm {
namespace tl {

constexpr const char *tvm_tensormap_create_tiled =
    "__tvm_tensormap_create_tiled";
constexpr const char *tvm_tensormap_create_im2col =
    "__tvm_tensormap_create_im2col";

} // namespace tl
} // namespace tvm

#endif // TVM_TL_RUNTIME_RUNTIME_H_
src/target/codegen_cuda.cc

@@ -6,9 +6,9 @@
 */
#include "codegen_cuda.h"

#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>

#include <cmath>
@@ -23,41 +23,51 @@

namespace tvm {
namespace codegen {

CodeGenTileLangCUDA::CodeGenTileLangCUDA() {
  restrict_keyword_ = "__restrict__";
}

void CodeGenTileLangCUDA::PrintFuncPrefix(std::ostream &os) {
  os << "extern \"C\" __global__ ";
}

class LaunchConfigExtractor : public tir::StmtVisitor {
private:
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->var->name_hint == "threadIdx.x" ||
          iv->thread_tag == "threadIdx.x") {
        threadIdx_x_ext = op->value;
      } else if (iv->var->name_hint == "threadIdx.y" ||
                 iv->thread_tag == "threadIdx.y") {
        threadIdx_y_ext = op->value;
      } else if (iv->var->name_hint == "threadIdx.z" ||
                 iv->thread_tag == "threadIdx.z") {
        threadIdx_z_ext = op->value;
      }
    }
    StmtVisitor::VisitStmt_(op);
  }

public:
  PrimExpr threadIdx_x_ext = Integer(1);
  PrimExpr threadIdx_y_ext = Integer(1);
  PrimExpr threadIdx_z_ext = Integer(1);
};

void CodeGenTileLangCUDA::PrintExtraAttrs(const PrimFunc &f, std::ostream &os) {
  LaunchConfigExtractor extractor;
  extractor(f->body);
  arith::Analyzer analyzer;
  PrimExpr threadIdx_ext =
      analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
                        extractor.threadIdx_z_ext);
  if (const IntImmNode *const threadIdx_ext_int =
          threadIdx_ext.as<IntImmNode>()) {
    if (threadIdx_ext_int->value == 1) {
      // unable to extract the number of threads per block, hence directly
      // return
      return;
    }
    stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
@@ -77,19 +87,20 @@ std::string CodeGenTileLangCUDA::Finish() {

  return CodeGenC::Finish();
}

void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode *op) {
  if (op->kind == tir::ForKind::kUnrolled) {
    PrintIndent();
    stream << "#pragma unroll\n";
  }
  std::string extent =
      PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
  PrintIndent();
  std::string vid = AllocVarID(op->loop_var.get());
  std::string start = PrintExpr(op->min);
  stream << "for (";
  PrintType(op->loop_var.dtype(), stream);
  stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent
         << "; ++" << vid << ") {\n";
  int for_scope = BeginScope();
  PrintStmt(op->body);
  this->EndScope(for_scope);
@@ -97,12 +108,13 @@ void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode* op) {

  stream << "}\n";
}

void CodeGenTileLangCUDA::BindThreadIndex(const IterVar &iv) {
  ICHECK(!var_idmap_.count(iv->var.get()));
  var_idmap_[iv->var.get()] =
      CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
}

void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
  int lanes = t.lanes();
  if (t.is_handle()) {
    ICHECK(t.is_scalar()) << "do not yet support vector types";
@@ -123,51 +135,54 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)

  bool fail = false;
  if (t.is_float()) {
    switch (t.bits()) {
    case 16:
      if (t.is_scalar()) {
        os << "half_t";
      } else if (lanes <= 8) {
        // Emit CUDA code to access fp16 vector elements.
        //
        // half4 is stored as uint2
        //
        // h4.x is emitted as *(half2*)(&(u2.x)).x
        // h4.y is emitted as *(half2*)(&(u2.x)).y
        // h4.z is emitted as *(half2*)(&(u2.y)).x
        // h4.w is emitted as *(half2*)(&(u2.y)).y
        //
        ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
        os << "uint" << lanes / 2;
      } else {
        fail = true;
        break;
      }
      break;
    case 32:
      if (lanes <= 4) {
        os << "float";
      } else if (lanes <= 8) {
        // Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8.
        //
        // float8 is stored as ulonglong4
        //
        // f8.v1 is emitted as *(float2*)(&(ul4.x)).x
        // f8.v2 is emitted as *(float2*)(&(ul4.x)).y
        //
        ICHECK_EQ(lanes % 2, 0)
            << "only support even lane for float type with lanes > 4";
        os << "ulonglong" << lanes / 2;
      } else {
        fail = true;
      }
      break;
    case 64:
      os << "double";
      break;
    default:
      fail = true;
      break;
    }
    if (!fail && (t.is_scalar() || t.bits() == 16))
      return;
    if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32))
      return;
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
@@ -181,18 +196,21 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)

    } else {
      fail = true;
    }
    if (!fail)
      return;
  } else if (t.is_float8()) {
    if (t.is_scalar()) {
      os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
    } else if (lanes == 2) {
      os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of
                                  // unsigned short
    } else if (lanes == 4) {
      os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
    } else {
      fail = true;
    }
    if (!fail)
      return;
  } else if (t == DataType::Bool()) {
    os << "bool";
    return;
@@ -209,133 +227,135 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)

    os << "u";
  }
  switch (t.bits()) {
  case 1: {
    if (t.is_scalar()) {
      os << "int";
      return;
    } else if (t.lanes() == 8) {
      os << "int8_t";
      return;
    } else if (t.lanes() == 16) {
      os << "int16_t";
      return;
    } else if (t.lanes() == 32) {
      os << "int";
      return;
    } else {
      LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
    }
  }
  case 4: {
    if (t.is_scalar()) {
      os << "int";
      return;
    } else if (t.lanes() == 4) {
      os << "int16_t";
      return;
    } else if (t.lanes() == 8) {
      // directly 8 4-bit int in integer.
      os << "int";
      return;
    } else if (t.lanes() == 16) {
      os << "int2";
      return;
    } else if (t.lanes() == 32) {
      os << "int4";
      return;
    } else if (t.lanes() == 64) {
      os << "int8";
      return;
    } else {
      LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
    }
  }
  case 8: {
    if (t.lanes() == 4) {
      // directly 4 8 bit int in integer.
      // We use int for int8x4 instead of char4 because using char4 is
      // likely to produce extra instructions to pack four int8 elements
      // into 32-bit data.
      os << "int";
      return;
    } else if (t.lanes() == 8) {
      os << "int2";
      return;
    } else if (t.lanes() == 16) {
      os << "int4";
      return;
    } else if (!t.is_uint() && t.is_scalar()) {
      os << "signed char";
      break;
    } else {
      os << "char";
      break;
    }
  }
  case 16: {
    if (t.is_scalar()) {
      os << "short";
    } else if (t.lanes() <= 4) {
      os << "short" << lanes;
    } else if (t.lanes() <= 8) {
      // Emit CUDA code to access int16 vector elements.
      //
      // short4 is stored as int2
      //
      // s4.x is emitted as *(short2*)(&(i2.x)).x
      // s4.y is emitted as *(short2*)(&(i2.x)).y
      // s4.z is emitted as *(short2*)(&(i2.y)).x
      // s4.w is emitted as *(short2*)(&(i2.y)).y
      //
      ICHECK_EQ(t.lanes() % 2, 0)
          << "only support even lane for shorT type with lanes > 4";
      os << "int" << t.lanes() / 2;
    } else {
      fail = true;
    }
    if (!fail) {
      return;
    }
    break;
  }
  case 32: {
    if (t.is_scalar()) {
      os << "int";
    } else if (t.lanes() <= 4) {
      os << "int" << t.lanes();
    } else if (t.lanes() <= 8) {
      // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
      //
      // int8 is stored as longlong4
      //
      // i8.v1 is emitted as *(int2*)(&(l4.x)).x
      // i8.v2 is emitted as *(int2*)(&(l4.x)).y
      //
      ICHECK_EQ(lanes % 2, 0)
          << "only support even lane for int32 type with lanes > 4";
      os << "longlong" << lanes / 2;
    } else {
      fail = true;
      break;
    }
    if (!fail) {
      return;
    }
    break;
  }
  case 64: {
    if (t.is_scalar()) {
      os << "int64_t";
    } else if (t.lanes() == 2) {
      os << "longlong2";
    } else if (t.lanes() == 3) {
      os << "longlong3";
    } else if (t.lanes() == 4) {
      os << "longlong4";
    }
    return;
  }
  default:
    fail = true;
    break;
  }
  if (!fail && lanes == 1) {
    return;
@@ -348,8 +368,9 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)

  LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}

void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string &op, DataType t,
                                           PrimExpr lhs, PrimExpr rhs,
                                           std::ostream &os) { // NOLINT(*)
  // Declare the result.
  std::string sret = name_supply_->FreshName("_");
  this->PrintIndent();
@@ -383,15 +404,18 @@ void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string& op, DataType t, Pr

  os << sret;
}

void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t,
                                           int i,
                                           std::ostream &os) { // NOLINT(*)
  if (t.is_scalar()) {
    os << vec;
    return;
  }

  static const char access[] = {'x', 'y', 'z', 'w'};
  ICHECK(i >= 0 &&
         i < (t.bits() == 8 ? 16
              : (t.bits() == 16 || t.bits() == 32) ? 8
                                                   : 4));
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    std::string type_name = t.is_int() ? "char" : "unsigned char";
    if (t.lanes() == 2 || t.lanes() == 3) {
@@ -401,9 +425,11 @@ void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string& vec, DataType t, i
os
<<
"(("
<<
type_name
<<
")("
<<
ac
<<
" >> "
<<
i
%
4
*
8
<<
"))"
;
}
}
else
if
(
t
.
is_float16
())
{
os
<<
"((half2*)(&("
<<
vec
<<
"."
<<
access
[
i
/
2
]
<<
")))->"
<<
access
[
i
%
2
];
os
<<
"((half2*)(&("
<<
vec
<<
"."
<<
access
[
i
/
2
]
<<
")))->"
<<
access
[
i
%
2
];
}
else
if
(
t
.
is_bfloat16
())
{
os
<<
"((nv_bfloat162*)(&("
<<
vec
<<
"."
<<
access
[
i
/
2
]
<<
")))->"
<<
access
[
i
%
2
];
os
<<
"((nv_bfloat162*)(&("
<<
vec
<<
"."
<<
access
[
i
/
2
]
<<
")))->"
<<
access
[
i
%
2
];
}
else
if
(
t
.
lanes
()
>
4
&&
t
.
lanes
()
<=
8
)
{
std
::
string
type_name
;
if
(
t
.
bits
()
==
16
)
{
...
...
@@ -422,20 +448,24 @@ void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string& vec, DataType t, i
      }
    }
    ICHECK(!type_name.empty());
    os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
       << ")))->" << access[i % 2];
  } else {
    os << vec << "." << access[i];
  }
}

void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t,
                                            int i, const std::string &value) {
  this->PrintIndent();
  static const char access[] = {'x', 'y', 'z', 'w'};
  ICHECK(i >= 0 && i < (t.bits() == 8                        ? 16
                        : (t.bits() == 16 || t.bits() == 32) ? 8
                                                             : 4));
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (t.lanes() == 2 || t.lanes() == 3) {
      stream << vec << '.' << access[i % t.lanes()] << "="
             << "(" << value << ");\n";
    } else {
      std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
      stream << ac << "=";
...
...
@@ -446,11 +476,11 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string& vec, DataType t,
      stream << "(" << value << " << " << i % 4 * 8 << ");\n";
    }
  } else if (t.is_float16()) {
    stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
           << access[i % 2] << " = " << value << ";\n";
  } else if (t.is_bfloat16()) {
    stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
           << access[i % 2] << " = " << value << ";\n";
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
...
...
@@ -469,15 +499,15 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string& vec, DataType t,
      }
    }
    ICHECK(!type_name.empty());
    stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
           << ")))->" << access[i % 2] << " = " << value << ";\n";
  } else {
    stream << vec << "." << access[i] << " = " << value << ";\n";
  }
}

void CodeGenTileLangCUDA::PrintStorageSync(const CallNode *op) {
  const std::string &sync = op->args[0].as<StringImmNode>()->value;
  if (sync == "warp") {
    // DO nothing.
  } else if (sync == "shared" || sync == "shared.dyn") {
...
...
@@ -486,9 +516,11 @@ void CodeGenTileLangCUDA::PrintStorageSync(const CallNode* op) {
  }
}

void CodeGenTileLangCUDA::PrintStorageScope(const std::string &scope,
                                            std::ostream &os) { // NOLINT(*)
  ICHECK_NE(scope, "global")
      << "Cannot allocate global memory when targeting CUDA. You must pass "
         "all global arrays as input instead";
  if (scope == "shared") {
    os << "__shared__ ";
  } else if (scope == "shared.dyn") {
...
...
@@ -496,13 +528,16 @@ void CodeGenTileLangCUDA::PrintStorageScope(const std::string& scope, std::ostre
  }
}

std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from,
                                            DataType target) {
  if (from == target)
    return value;
  std::ostringstream os;
  os << "((";
  this->PrintType(target, os);
  os << ")";
  if (from.is_float16() && (target.is_int() || target.is_uint()) &&
      target.bits() == 8) {
    os << "(";
    if (target.is_uint()) {
      os << "u";
...
...
@@ -513,13 +548,14 @@ std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from, Da
  return os.str();
}

void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
  DataType from_ty = op->value.dtype();
  DataType target_ty = op->dtype;
  ICHECK_EQ(target_ty.lanes(), from_ty.lanes());

  // Emit simple C-style type conversion.
  if (from_ty.is_scalar())
    return CodeGenC::VisitExpr_(op, os);

  // We could emit make_float4 like calls, but the emitted code looks
  // too compact to read. Emit this as vectorized unary ops.
...
...
@@ -542,8 +578,10 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
  os << sret;
}

void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol,
                                          const Array<PrimExpr> &args,
                                          bool skip_first_arg,
                                          std::ostream &os) { // NOLINT(*)
  DataType ret_dtype = GetRuntimeDataType(ret_type);
  if (ret_dtype.is_vector()) {
    //
...
...
@@ -583,7 +621,8 @@ void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol, c
      std::ostringstream scall;
      scall << global_symbol << "(";
      for (size_t j = 0; j < sargs.size(); ++j) {
        if (j > 0)
          scall << ", ";
        PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall);
      }
      scall << ")";
...
...
@@ -592,13 +631,16 @@ void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol, c
    }
    os << sret;
  } else {
    CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg,
                              os);
  }
}

// Print a reference expression to a buffer.
std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
                                              const BufferNode *buffer,
                                              PrimExpr index) {
  const VarNode *buffer_var = buffer->data.get();
  std::ostringstream os;
  std::string vid = GetVarID(buffer_var);
  std::string scope;
...
...
@@ -654,12 +696,13 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, const BufferNode* buff
  return os.str();
}

void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
  auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) {
    this->PrintIndent();
    this->stream << name << "(";
    for (size_t i = offset; i < op->args.size(); i++) {
      if (i > offset)
        this->stream << ", ";
      this->stream << this->PrintExpr(op->args[i]);
    }
    this->stream << ");\n";
...
...
@@ -670,16 +713,18 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
    // use size of argument list to indicate whether or not to use predicated
    // cp.async
    if (op->args.size() == 5) {
      this->PrintIndent();
      this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+"
                   << dst_offset << ", " << src << "+" << src_offset << ");\n";
    } else {
      std::string condition = this->PrintExpr(op->args[5]);
      this->PrintIndent();
      this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst
                   << "+" << dst_offset << ", " << src << "+" << src_offset
                   << ", " << condition << ");\n";
    }
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
    print_extern_call_stmt("tl::cp_async_commit");
...
...
@@ -691,7 +736,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    this->PrintIndent();
    int barrier_count = Downcast<IntImm>(op->args[0])->value;
    std::string barrier_name = "_mbarrier";
    this->stream << "__shared__ uint64_t " << barrier_name << "["
                 << barrier_count << "];\n";
  } else if (op->op.same_as(tl::GetMBarrierOp())) {
    std::string barrier_name = "_mbarrier";
    std::string barrier_id = this->PrintExpr(op->args[0]);
...
...
@@ -720,13 +766,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
    if (trans == 1)
      func_name += "_trans";
    print_extern_call_stmt(func_name, 2);
  } else if (op->op.same_as(tl::STMatrixOp())) {
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
    if (trans == 1)
      func_name += "_trans";
    print_extern_call_stmt(func_name, 2);
  } else if (op->op.same_as(tl::FenceProxyAsyncOp())) {
    print_extern_call_stmt("tl::fence_proxy_async");
...
...
@@ -734,15 +782,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    this->PrintIndent();
    int nreg = Downcast<IntImm>(op->args[0])->value;
    int is_inc = Downcast<IntImm>(op->args[1])->value;
    std::string func_name =
        is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
    this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
  } else if (op->op.same_as(tl::WaitWgmma())) {
    this->PrintIndent();
    int num_mma = Downcast<IntImm>(op->args[0])->value;
    this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
  } else if (op->op.same_as(tl::PackB16Op())) {
    os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
       << this->PrintExpr(op->args[1]) << ")";
  } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 6U);
...
...
@@ -776,7 +825,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[6], os);
    if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
      os << ", nvcuda::wmma::mem_" << str->value;
    } else {
      LOG(FATAL) << "Invalid parameters";
...
...
@@ -831,10 +880,11 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_bias = this->PrintExpr(op->args[11]);
    bool saturate = Downcast<Bool>(op->args[12])->value;
    std::string bit_op =
        op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
    std::string asm_code = PrintMMAAssembly(
        shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
        b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
    this->stream << asm_code;
  } else if (op->op.same_as(builtin::ptx_mma_sp())) {
...
...
@@ -872,8 +922,9 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    std::string sparse_selector = this->PrintExpr(op->args[14]);
    bool saturate = Downcast<Bool>(op->args[15])->value;
    std::string asm_code = PrintMMAAssembly(
        shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset,
        b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset,
        sparse_selector, "", true, saturate);
    this->stream << asm_code;
  } else if (op->op.same_as(builtin::ptx_ldmatrix())) {
    // arg 0: whether the matrix is loaded in column major format or not.
...
...
@@ -882,7 +933,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    // arg 3: pointer to local buffer.
    // arg 4: The offset of the element to store in the local buffer.
    // arg 5: pointer to the shared memory buffer to load.
    // arg 6: The offset of the start element of the row to load in shared
    // memory.
    ICHECK_EQ(op->args.size(), 7U);
    bool trans = Downcast<Bool>(op->args[0])->value;
    int num = Downcast<Integer>(op->args[1])->value;
...
...
@@ -891,20 +943,23 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    std::string local_elem_offset = this->PrintExpr(op->args[4]);
    std::string smem_ptr = this->PrintExpr(op->args[5]);
    if (trans && op->dtype.bits() == 8) {
      // Since ldmatrix assumes that a matrix element is 16 bit, it cannot
      // properly transpose an int8 matrix.
      std::string smem_stride = this->PrintExpr(op->args[6]);
      ICHECK(num == 4);
      os << "for (int i = 0; i < 16; ++i) {\n";
      os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr
         << "[(i % 8) / 4 * " + smem_stride +
                " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride +
                "+ (i % 4) * " + smem_stride +
                " + threadIdx.x / 4 + (i / 8) * 8];\n";
      os << "}\n";
    } else {
      std::string smem_elem_offset = this->PrintExpr(op->args[6]);
      need_cast_smem_ptr_to_int_ = true;
      this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr,
                                              local_elem_offset, smem_ptr,
                                              smem_elem_offset);
    }
  } else if (op->op.same_as(builtin::mma_store())) {
    int m = Downcast<Integer>(op->args[0])->value;
...
...
@@ -914,29 +969,31 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    std::string src_offset = this->PrintExpr(op->args[4]);
    PrimExpr stride = op->args[5];
    ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for now";
    // Each thread in a warp holds a certain number of elements of an MMA
    // output. For example, if we compute a 16x16 tile using MMA, each thread
    // holds 8 elements in its registers. So conceptually, a warp memory is
    // organized as a 32x8 block. A map from a 16x16 tile to a 32x8 block of
    // memory is specified by the index map below.
    // To store the 32x8 output back to a 16x16 tile in shared or global memory,
    // we invert this map to determine the output location for each 8 element.
    const auto *index_map_func =
        runtime::Registry::Get("tir.index_map.shared_16x16_to_mma_32x8_layout");
    IndexMap index_map;
    if (!index_map_func) {
      Var i, j;
      // The index map is defined as follows:
      index_map = IndexMap(
          {i, j}, {4 * FloorMod(i, 8) + FloorDiv(FloorMod(j, 8), 2),
                   4 * FloorDiv(j, 8) + FloorDiv(i, 8) * 2 + FloorMod(j, 2)});
    } else {
      index_map = IndexMap::FromFunc(2, *index_map_func);
    }

    arith::Analyzer analyzer;
...
...
@@ -944,20 +1001,21 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
        index_map.Inverse({Range(0, m), Range(0, n)}, &analyzer);
    auto indices_16x16 = inverse_index_map->final_indices;

    // "//" and "%" in the index map are translated to FloorDiv/Mod, but the
    // plain Div/Mod are fine. FloorDiv/Mod are supposed to be lowered before
    // they reach codegen, so manually replace them to the plain ones here.
    class LowerFloorDivMod : public ExprMutator {
    public:
      PrimExpr VisitExpr_(const FloorDivNode *op) {
        return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b));
      }
      PrimExpr VisitExpr_(const FloorModNode *op) {
        return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b));
      }
    };

    auto dst_ind =
        LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]);

    var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x";
    var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id";
...
...
@@ -967,8 +1025,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
         << " = "
         << "*((uint *)&" << src << "[" << src_offset << " + local_id]);\n";
      os << "}\n";
    } else {
      os << "for (int local_id = 0; local_id < 8; ++local_id) {\n";
      os << dst << "[" + this->PrintExpr(dst_ind) + "]"
         << " = " << src << "[" << src_offset << " + local_id];\n";
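Inverting the default 16x16-to-32x8 map above gives each thread the destination row and column for every one of its 8 accumulator values. A standalone sketch (not the generated code verbatim; "C", "C_local" and "stride" are hypothetical names) of what the emitted store loop computes when that default map is used:

// Assumes the index map (i, j) -> (4*(i%8) + (j%8)/2, 4*(j/8) + (i/8)*2 + j%2);
// the body below is its inverse applied per (threadIdx.x, local_id) pair.
__device__ void store_mma_16x16(float *C, const float *C_local, int stride) {
  int tid = threadIdx.x;  // lane id 0..31 within the warp
  for (int local_id = 0; local_id < 8; ++local_id) {
    int row = 8 * ((local_id % 4) / 2) + tid / 4;
    int col = 8 * (local_id / 4) + (tid % 4) * 2 + local_id % 2;
    C[row * stride + col] = C_local[local_id];
  }
}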
...
...
@@ -990,12 +1047,14 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
    need_cast_smem_ptr_to_int_ = true;
    // use size of argument list to indicate whether or not to use predicated
    // cp.async
    if (op->args.size() == 5) {
      this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset,
                                           size);
    } else {
      this->stream << PrintPredicatedCpAsyncAssembly(
          dst, dst_offset, src, src_offset, size,
          this->PrintExpr(op->args[5]));
    }
  } else if (op->op.same_as(builtin::ptx_cp_async_bulk())) {
    need_cast_smem_ptr_to_int_ = true;
...
...
@@ -1006,44 +1065,52 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    std::string size = this->PrintExpr(op->args[4]);
    int barrier_id = Downcast<IntImm>(op->args[5])->value;
    CHECK(barrier_id < barrier_count_);
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
    this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size,
                                        barrier);
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
    this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
  } else if (op->op.same_as(builtin::ptx_wait_group())) {
    int n = Downcast<IntImm>(op->args[0])->value;
    this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n
                 << ";\");\n\n";
  } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
    this->stream << PrintCpAsyncBarrierAsm(barrier);
  } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
    std::string thread_count = this->PrintExpr(op->args[1]);
    this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
  } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
    this->stream << PrintArriveBarrierAsm(barrier);
  } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
    std::string byte_count = this->PrintExpr(op->args[1]);
    this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count);
  } else if (op->op.same_as(builtin::ptx_wait_barrier())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
    this->stream << PrintWaitBarrierAsm(barrier);
  } else if (op->op.same_as(builtin::create_barriers())) {
    CHECK_EQ(barrier_count_, -1);
...
...
@@ -1052,13 +1119,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    CHECK_EQ(barrier_alignment_bytes_ % sizeof(uint64_t), 0);
    int barrier_alignment_count = barrier_alignment_bytes_ / sizeof(uint64_t);
    if (barrier_count % barrier_alignment_count != 0) {
      barrier_count = ((barrier_count / barrier_alignment_count) + 1) *
                      barrier_alignment_count;
    }
    barrier_count_ = barrier_count;
    this->stream << "__shared__ __align__(" << barrier_alignment_bytes_
                 << ") uint64_t " << barrier_name_ << "[" << barrier_count
                 << "];\n";
    this->stream << "for (int i = 0; i < " << barrier_count << "; ++i) { "
                 << barrier_name_ << "[i] = 0; }\n";
  } else if (op->op.same_as(builtin::ptx_ldg32())) {
    /*
    asm volatile (
...
...
@@ -1075,7 +1144,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    std::string reg = this->PrintExpr(op->args[0]);
    // get guard
    std::string guard = this->PrintExpr(op->args[1]);
    const BufferLoadNode *addr_buffer = op->args[2].as<BufferLoadNode>();
    std::string global_addr = this->PrintExpr(addr_buffer->indices[0]);
    std::string global_buffer = this->PrintExpr(addr_buffer->buffer->data);
    std::string local_addr = this->PrintExpr(op->args[3]);
...
...
@@ -1087,26 +1156,27 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
    stream << ": \"=f\"(" << reg << "[" << local_addr << "]" << ")\n";
    stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr
           << ")), \"r\"((int)" << guard << ")\n";
    stream << ");\n";
  } else {
    CodeGenC::VisitExpr_(op, os);
  }
}

void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) {
  if (op->attr_key == tir::attr::fragment_shape) {
    const VarNode *buffer = op->node.as<VarNode>();
    const StringImmNode *shape_str = op->value.as<StringImmNode>();
    fragment_shapes[buffer] = shape_str->value;
  } else if (op->attr_key == tir::attr::fragment_layout) {
    const VarNode *buffer = op->node.as<VarNode>();
    const StringImmNode *layout_str = op->value.as<StringImmNode>();
    fragment_layouts[buffer] = layout_str->value;
  } else if (op->attr_key == tir::attr::async_commit_queue_scope) {
    const IntImmNode *queue_id = op->value.as<IntImmNode>();
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
    this->VisitStmt(op->body);
    auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
    this->VisitExpr(commit_group, this->stream);
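Once the async queue attributes are lowered this way, the generated kernel simply carries the corresponding cp.async inline assembly at the end of the commit scope and at the wait point. A rough, illustrative device-side sketch (requires sm_80 or newer for cp.async; the wait count is hard-coded to 0 here purely for illustration):

__device__ void commit_and_wait_all_copies() {
  // emitted after the copies inside an async_commit_queue_scope
  __asm__ __volatile__("cp.async.commit_group;");
  // emitted where the matching async_wait_queue_scope begins
  __asm__ __volatile__("cp.async.wait_group 0;");
}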
...
...
@@ -1114,9 +1184,11 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode* op) {
  } else if (op->attr_key == tir::attr::async_wait_queue_scope) {
    auto wait_attrs = GetAsyncWaitAttributes(op);
    auto queue_id = wait_attrs.first.as<IntImmNode>();
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
    auto wait_cnt = wait_attrs.second;
    auto wait_group =
        Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
    this->VisitExpr(wait_group, this->stream);
    auto inner = op->body.as<AttrStmtNode>();
    ICHECK(inner);
...
...
@@ -1124,7 +1196,7 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode* op) {
    return;
  } else if (op->attr_key == "threadblock_swizzle_pattern") {
    this->PrintIndent();
    const StringImmNode *pattern = op->value.as<StringImmNode>();
    ICHECK(pattern);
    this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
    this->VisitStmt(op->body);
...
...
@@ -1133,28 +1205,28 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode* op) {
    CodeGenC::VisitStmt_(op);
}

void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
  ICHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());

  this->PrintIndent();
  std::string scope = GetPtrStorageScope(op->buffer_var);
  const VarNode *buffer = op->buffer_var.as<VarNode>();
  if (scope.find("wmma.") == 0) {
    if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
      ICHECK(op->dtype == DataType::Float(16) ||
             op->dtype == DataType::Int(8) || op->dtype == DataType::UInt(8) ||
             op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
             op->dtype == DataType::Int(1) ||
             op->dtype == DataType::BFloat(16))
          << "Matrix_a and matrix_b only support half or char or unsigned char "
          << "or uint4 or int4 or int1 type for now";
    } else {
      ICHECK(op->dtype == DataType::Float(16) ||
             op->dtype == DataType::Float(32) ||
             op->dtype == DataType::Int(32))
          << "Accumulator only support half, float and int type for now";
    }
    PrintWmmaScope(scope, op->dtype, buffer, stream);
  } else {
    PrintStorageScope(scope, stream);
    PrintType(op->dtype, stream);
  }
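For a wmma-scoped allocation the branch above ends up printing nvcuda::wmma fragment types rather than plain arrays. A hypothetical example of what the emitted declarations look like for a 16x16x16 half GEMM tile (shape and layout strings are assumptions; the real codegen prints its own type names such as half_t):

#include <mma.h>
// Requires sm_70 or newer.
__device__ void declare_wmma_fragments() {
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half,
                         nvcuda::wmma::row_major> A_frag[1];
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half,
                         nvcuda::wmma::col_major> B_frag[1];
  nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> C_frag[1];
  nvcuda::wmma::fill_fragment(C_frag[0], 0.0f);  // illustrative use only
}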
...
...
@@ -1163,7 +1235,8 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode* op) {
    stream << ' ' << vid << "[];\n";
  } else {
    size_t constant_size = op->ConstantAllocationSize();
    ICHECK_GT(constant_size, 0)
        << "Can only handle constant size stack allocation for now";
    if (scope.find("wmma.") == 0) {
      constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
    }
...
...
@@ -1179,7 +1252,7 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode* op) {
  this->PrintStmt(op->body);
}

void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
  CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
  os << "(make_";
...
...
@@ -1188,16 +1261,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
  for (int i = 0; i < lanes; i++) {
    os << "(" << PrintExpr(op->base) << ")"
       << "+(" << PrintExpr(op->stride) << "*" << i << ")";
    if (i != lanes - 1)
      os << ", ";
  }
  os << "))";
}

void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
                                     std::ostream &os) { // NOLINT(*)
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 &&
      lanes == 4) {
    // make_int8x4
    const int64_t *p = as_const_int(op->value);
    ICHECK(p);
    int64_t v = *p & 0xFF;
    v = (v << 24) | (v << 16) | (v << 8) | v;
...
...
@@ -1215,7 +1291,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < lanes / 2; ++i) {
      if (i != 0)
        os << ", ";
      os << "__pack_half2(" << v << ", " << v << ")";
    }
    os << ')';
...
...
@@ -1228,18 +1305,21 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < lanes / 2; ++i) {
      if (i != 0)
        os << ", ";
      os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  if (op->dtype.is_float() && op->dtype.bits() == 32 &&
      op->dtype.lanes() == 8) {
    std::string v = PrintExpr(op->value);
    os << "make_ulonglong4(";
    for (int i = 0; i < 4; ++i) {
      if (i != 0)
        os << ", ";
      os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")";
    }
    os << ')';
...
...
@@ -1248,7 +1328,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
    bool fail = false;
    const int64_t *p = as_const_int(op->value);
    ICHECK(p);
    int64_t v = *p & 0xF;
...
...
@@ -1260,7 +1340,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
        os << "(int16_t)" << v;
      }
    } else {
      v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) |
          (v << 8) | (v << 4) | v;
      if (lanes == 8) {
        if (op->dtype.is_uint()) {
          os << "(uint)" << v;
...
...
@@ -1272,7 +1353,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
        PrintType(op->dtype, os);
        os << '(';
        for (int i = 0; i < lanes / 8; ++i) {
          if (i != 0)
            os << ", ";
          if (op->dtype.is_uint()) {
            os << "(uint)" << v;
          } else {
...
...
@@ -1295,13 +1377,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
  PrintType(op->dtype, os);
  os << '(';
  for (int i = 0; i < lanes; ++i) {
    if (i != 0)
      os << ", ";
    os << v;
  }
  os << ')';
}

inline void PrintConst(const FloatImmNode *op, std::ostream &os,
                       CodeGenTileLangCUDA *p) { // NOLINT(*)
  // Type code is kBFloat
  if (op->dtype.is_bfloat16()) {
    os << "bfloat16_t";
...
...
@@ -1310,46 +1394,49 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLang
  }
  // Type code is kFloat
  switch (op->dtype.bits()) {
  case 64:
  case 32: {
    std::ostringstream temp;
    if (std::isinf(op->value)) {
      if (op->value < 0) {
        temp << "-";
      }
      temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
    } else if (std::isnan(op->value)) {
      temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
    } else {
      temp << std::scientific << op->value;
      if (op->dtype.bits() == 32)
        temp << 'f';
    }
    p->MarkConst(temp.str());
    os << temp.str();
    break;
  }
  case 16: {
    os << "half_t" << '(';
    FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
    PrintConst(const_f32.get(), os, p);
    os << ')';
    break;
  }
  default:
    LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
  }
}

void CodeGenTileLangCUDA::VisitExpr_(const FloatImmNode *op,
                                     std::ostream &os) { // NOLINT(*)
  PrintConst(op, os, this);
}

void CodeGenTileLangCUDA::PrintWmmaScope(const std::string &scope, DataType t,
                                         const VarNode *variable,
                                         std::ostream &os) {
  std::stringstream type;
  PrintType(t, type);
  ICHECK(fragment_shapes.count(variable))
      << "Cannot find shape of the wmma fragment " << variable->name_hint;
  std::string shape_str = fragment_shapes.at(variable);
  if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
    type.str(std::string());
...
...
@@ -1372,23 +1459,24 @@ void CodeGenTileLangCUDA::PrintWmmaScope(const std::string& scope, DataType t,
  if (scope == "wmma.matrix_a") {
    std::string layout_str = fragment_layouts[variable];
    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a";
    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str
       << ", " << type.str() << ", nvcuda::wmma::" << layout_str << ">";
  } else if (scope == "wmma.matrix_b") {
    std::string layout_str = fragment_layouts[variable];
    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b";
    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str
       << ", " << type.str() << ", nvcuda::wmma::" << layout_str << ">";
  } else if (scope == "wmma.accumulator") {
    os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str
       << ", " << type.str() << ">";
  }
}

int32_t CodeGenTileLangCUDA::GetWmmaFragmentSize(const std::string &scope,
                                                 const VarNode *variable,
                                                 int32_t size) {
  ICHECK(fragment_shapes.count(variable))
      << "Cannot find shape of the wmma fragment " << variable->name_hint;
  std::string shape_str = fragment_shapes.at(variable);
  std::pair<int32_t, int32_t> dim = GetWmmaFragmentDimSize(shape_str, scope);
  if (dim.first * dim.second != 0)
...
...
@@ -1397,12 +1485,14 @@ int32_t CodeGenTileLangCUDA::GetWmmaFragmentSize(const std::string& scope, const
    return 0;
}

void CodeGenTileLangCUDA::HandleVolatileLoads(const std::string &value,
                                              const BufferLoadNode *op,
                                              std::ostream &os) {
  // Cast away volatile qualifier for fp16 types. That is, only loads and
  // stores are volatile. The loaded objects are not marked as volatile.
  //
  if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) &&
      IsVolatile(op->buffer->data.get())) {
    os << "(";
    PrintType(op->dtype, os);
    os << ")(" << value << ")";
...
...
@@ -1411,15 +1501,17 @@ void CodeGenTileLangCUDA::HandleVolatileLoads(const std::string& value, const Bu
  }
}

void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i,
                                               const std::string &value,
                                               std::ostream &os) {
  ICHECK_GT(t.lanes(), 1);
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (!(t.lanes() == 2 || t.lanes() == 3)) {
      if (i != 0) {
        os << "|";
      }
      os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8
         << "))";
      return;
    }
  }
...
...
@@ -1476,7 +1568,7 @@ void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::str
  return;
}

void CodeGenTileLangCUDA::AddFunction(const PrimFunc &f) {
  // clear previous generated state.
  this->InitFuncState(f);
  // reserve keywords
...
...
@@ -1495,10 +1587,11 @@ void CodeGenTileLangCUDA::AddFunction(const PrimFunc& f) {
  for (size_t i = 0; i < f->params.size(); ++i) {
    tir::Var v = f->params[i];
    std::string vid = AllocVarID(v.get());
    if (i != 0)
      stream << ", ";
    if (v.dtype().is_handle()) {
      // work around for grid constant parameters.
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (ptr->storage_scope == "grid_constant") {
          stream << "__grid_constant__ const ";
          CodeGenC::PrintType(ptr->element_type, stream);
...
...
@@ -1513,8 +1606,8 @@ void CodeGenTileLangCUDA::AddFunction(const PrimFunc& f) {
    }
    CodeGenC::PrintType(GetType(v), stream);
    if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
      if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
        RegisterHandleType(v.get(), prim->dtype);
      }
    }
...
...
@@ -1536,5 +1629,5 @@ void CodeGenTileLangCUDA::AddFunction(const PrimFunc& f) {
  this->stream << "}\n\n";
}

} // namespace codegen
} // namespace tvm
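Putting AddFunction, PrintFuncPrefix and PrintExtraAttrs together, the overall shape of an emitted kernel signature looks roughly like the sketch below. This is illustrative only: the kernel name, argument list and launch bound are made up, and __grid_constant__ (CUDA 11.7+, sm_70+) is shown on a hypothetical by-value struct parameter, which is the case the grid_constant workaround above targets.

struct KernelArgs { float alpha; int n; };  // hypothetical

extern "C" __global__ void __launch_bounds__(128)
    main_kernel(float *__restrict__ A, float *__restrict__ B,
                __grid_constant__ const KernelArgs args) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < args.n) B[i] = args.alpha * A[i];
}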
src/target/codegen_cuda.h
View file @
549416f7
...
...
@@ -21,50 +21,58 @@ namespace tvm {
namespace codegen {

class CodeGenTileLangCUDA final : public CodeGenC {
public:
  CodeGenTileLangCUDA();
  std::string Finish();
  // override behavior
  void PrintFuncPrefix(std::ostream &os) final;
  void PrintExtraAttrs(const PrimFunc &f, std::ostream &os) final;
  void VisitStmt_(const ForNode *op) final;
  void PrintStorageSync(const CallNode *op) final;
  void PrintStorageScope(const std::string &scope,
                         std::ostream &os) final; // NOLINT(*)
  void PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs,
                        PrimExpr rhs, std::ostream &os) final; // NOLINT(*)
  void PrintType(DataType t, std::ostream &os) final;          // NOLINT(*)
  void PrintVecElemLoad(const std::string &vec, DataType t, int i,
                        std::ostream &os) final; // NOLINT(*)
  void PrintVecElemStore(const std::string &vec, DataType t, int i,
                         const std::string &value) final;
  void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
  void PrintVecElemLoadExpr(DataType t, int i, const std::string &value,
                            std::ostream &os) final;
  std::string CastFromTo(std::string value, DataType from,
                         DataType target) final;
  // overload visitor
  void VisitExpr_(const RampNode *op, std::ostream &os) final;      // NOLINT(*)
  void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
  void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
  void VisitExpr_(const CallNode *op, std::ostream &os) final;
  void VisitExpr_(const CastNode *op, std::ostream &os) final;
  void VisitStmt_(const AllocateNode *op) final;
  void VisitStmt_(const AttrStmtNode *op) final;
  // Override this as a work around for __grid_constant__ parameter
  void AddFunction(const PrimFunc &f);

protected:
  virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
                                   PrimExpr index) final;
  void PrintCallExtern(Type ret_type, String global_symbol,
                       const Array<PrimExpr> &args, bool skip_first_arg,
                       std::ostream &os) final; // NOLINT(*)

private:
  // Handle volatile loads
  void HandleVolatileLoads(const std::string &value, const BufferLoadNode *op,
                           std::ostream &os) final;

  // Whether scope such as "__shared__" or "__constant__" is part of type.
  bool IsScopePartOfType() const final { return false; }

  friend void PrintConst(const FloatImmNode *op, std::ostream &os,
                         CodeGenTileLangCUDA *p);

  // The size of the barrier array in shared memory
  int barrier_count_ = -1;
  // whether need mma.h
...
...
@@ -77,15 +85,17 @@ class CodeGenTileLangCUDA final : public CodeGenC {
  // Set to 16 to maintain minimum alignment requirements for async bulk copy
  const int barrier_alignment_bytes_ = 16;

  std::unordered_map<const VarNode *, std::string> fragment_shapes;
  std::unordered_map<const VarNode *, std::string> fragment_layouts;
  friend void PrintConst(const FloatImmNode *op, std::ostream &os,
                         CodeGenTileLangCUDA *p);
  void PrintWmmaScope(const std::string &scope, DataType t,
                      const VarNode *variable, std::ostream &os);
  int32_t GetWmmaFragmentSize(const std::string &scope,
                              const VarNode *variable, int32_t size);
};

} // namespace codegen
} // namespace tvm

#endif // TVM_TL_TARGET_CODEGEN_CUDA_H_
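For orientation, the class is driven like any other TVM CodeGenC backend: initialize, feed it PrimFuncs, then collect the generated CUDA source. A minimal sketch, assuming the usual TVM driver pattern (the real entry point lives in rt_mod_cuda.cc and also handles targets and function metadata):

#include "codegen_cuda.h"

// Hypothetical helper, not part of the repository.
std::string EmitCUDASource(const tvm::tir::PrimFunc &f) {
  tvm::codegen::CodeGenTileLangCUDA cg;
  cg.Init(/*output_ssa=*/false);  // inherited from CodeGenC
  cg.AddFunction(f);
  return cg.Finish();
}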
src/target/codegen_hip.cc
View file @
549416f7
...
...
@@ -6,9 +6,9 @@
*/
#include "codegen_hip.h"
#include <tvm/tir/index_map.h>
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <cmath>
...
...
@@ -28,12 +28,13 @@ namespace codegen {
* \note should use std::format instead when codebase is ported to C++20.
*/
class
Replacer
{
public:
void
register_rule
(
const
std
::
string
&
pattern
,
const
std
::
string
&
replacement
)
{
public:
void
register_rule
(
const
std
::
string
&
pattern
,
const
std
::
string
&
replacement
)
{
_rules
.
emplace_back
(
pattern
,
replacement
);
}
std
::
string
rewrite
(
std
::
string
str
)
{
for
(
auto
&&
rule
:
_rules
)
{
for
(
auto
&&
rule
:
_rules
)
{
auto
[
pattern
,
replacement
]
=
rule
;
size_t
len
=
pattern
.
size
();
size_t
new_len
=
replacement
.
size
();
...
...
@@ -47,46 +48,53 @@ class Replacer {
  }
  void empty_rules() { _rules.clear(); }

private:
  std::vector<std::pair<std::string, std::string>> _rules;
};

CodeGenTileLangHIP::CodeGenTileLangHIP() { restrict_keyword_ = "__restrict__"; }

void CodeGenTileLangHIP::PrintFuncPrefix(std::ostream &os) {
  os << "extern \"C\" __global__ ";
}

class LaunchConfigExtractor : public tir::StmtVisitor {
private:
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->var->name_hint == "threadIdx.x" ||
          iv->thread_tag == "threadIdx.x") {
        threadIdx_x_ext = op->value;
      } else if (iv->var->name_hint == "threadIdx.y" ||
                 iv->thread_tag == "threadIdx.y") {
        threadIdx_y_ext = op->value;
      } else if (iv->var->name_hint == "threadIdx.z" ||
                 iv->thread_tag == "threadIdx.z") {
        threadIdx_z_ext = op->value;
      }
    }
    StmtVisitor::VisitStmt_(op);
  }

public:
  PrimExpr threadIdx_x_ext = Integer(1);
  PrimExpr threadIdx_y_ext = Integer(1);
  PrimExpr threadIdx_z_ext = Integer(1);
};

void CodeGenTileLangHIP::PrintExtraAttrs(const PrimFunc &f, std::ostream &os) {
  LaunchConfigExtractor extractor;
  extractor(f->body);
  arith::Analyzer analyzer;
  PrimExpr threadIdx_ext =
      analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
                        extractor.threadIdx_z_ext);
  if (const IntImmNode *const threadIdx_ext_int =
          threadIdx_ext.as<IntImmNode>()) {
    if (threadIdx_ext_int->value == 1) {
      // unable to extract the number of threads per block, hence directly
      // return
      return;
    }
    stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
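The extractor multiplies the threadIdx.x/y/z extents to get the block size, so with extents 128 x 2 x 1 the HIP codegen would prefix the kernel roughly as in the illustrative sketch below (the kernel name and parameters are made up):

extern "C" __global__ void __launch_bounds__(256)
    gemm_kernel(float *__restrict__ A, float *__restrict__ B) {
  // body elided
}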
...
...
@@ -108,19 +116,20 @@ std::string CodeGenTileLangHIP::Finish() {
  return CodeGenC::Finish();
}

void CodeGenTileLangHIP::VisitStmt_(const tir::ForNode *op) {
  if (op->kind == tir::ForKind::kUnrolled) {
    PrintIndent();
    stream << "#pragma unroll\n";
  }
  std::string extent =
      PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
  PrintIndent();
  std::string vid = AllocVarID(op->loop_var.get());
  std::string start = PrintExpr(op->min);
  stream << "for (";
  PrintType(op->loop_var.dtype(), stream);
  stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent
         << "; ++" << vid << ") {\n";
  int for_scope = BeginScope();
  PrintStmt(op->body);
  this->EndScope(for_scope);
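For an unrolled loop with min 0 and extent 4, the printer above produces device code of roughly this shape (variable names are illustrative):

__device__ void unrolled_copy(float *dst, const float *src) {
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    dst[i] = src[i];
  }
}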
...
...
@@ -128,12 +137,13 @@ void CodeGenTileLangHIP::VisitStmt_(const tir::ForNode* op) {
  stream << "}\n";
}

void CodeGenTileLangHIP::BindThreadIndex(const IterVar &iv) {
  ICHECK(!var_idmap_.count(iv->var.get()));
  var_idmap_[iv->var.get()] =
      CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
}

void CodeGenTileLangHIP::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
  int lanes = t.lanes();
  if (t.is_handle()) {
    ICHECK(t.is_scalar()) << "do not yet support vector types";
...
...
@@ -154,51 +164,54 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
  bool fail = false;
  if (t.is_float()) {
    switch (t.bits()) {
    case 16:
      if (t.is_scalar()) {
        os << "half_t";
      } else if (lanes <= 8) {
        // Emit CUDA code to access fp16 vector elements.
        //
        // half4 is stored as uint2
        //
        // h4.x is emitted as *(half2*)(&(u2.x)).x
        // h4.y is emitted as *(half2*)(&(u2.x)).y
        // h4.z is emitted as *(half2*)(&(u2.y)).x
        // h4.w is emitted as *(half2*)(&(u2.y)).y
        //
        ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
        os << "uint" << lanes / 2;
      } else {
        fail = true;
        break;
      }
      break;
    case 32:
      if (lanes <= 4) {
        os << "float";
      } else if (lanes <= 8) {
        // Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8.
        //
        // float8 is stored as ulonglong4
        //
        // f8.v1 is emitted as *(float2*)(&(ul4.x)).x
        // f8.v2 is emitted as *(float2*)(&(ul4.x)).y
        //
        ICHECK_EQ(lanes % 2, 0)
            << "only support even lane for float type with lanes > 4";
        os << "ulonglong" << lanes / 2;
      } else {
        fail = true;
      }
      break;
    case 64:
      os << "double";
      break;
    default:
      fail = true;
      break;
    }
    if (!fail && (t.is_scalar() || t.bits() == 16))
      return;
    if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32))
      return;
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
...
...
@@ -212,18 +225,21 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
    } else {
      fail = true;
    }
    if (!fail)
      return;
  } else if (t.is_float8()) {
    if (t.is_scalar()) {
      os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
    } else if (lanes == 2) {
      os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of
                                  // unsigned short
    } else if (lanes == 4) {
      os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
    } else {
      fail = true;
    }
    if (!fail)
      return;
  } else if (t == DataType::Bool()) {
    os << "bool";
    return;
...
...
@@ -240,133 +256,135 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
    os << "u";
  }
  switch (t.bits()) {
  case 1: {
    if (t.is_scalar()) {
      os << "int";
      return;
    } else if (t.lanes() == 8) {
      os << "int8_t";
      return;
    } else if (t.lanes() == 16) {
      os << "int16_t";
      return;
    } else if (t.lanes() == 32) {
      os << "int";
      return;
    } else {
      LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
    }
  }
  case 4: {
    if (t.is_scalar()) {
      os << "int";
      return;
    } else if (t.lanes() == 4) {
      os << "int16_t";
      return;
    } else if (t.lanes() == 8) {
      // directly 8 4-bit int in integer.
      os << "int";
      return;
    } else if (t.lanes() == 16) {
      os << "int2";
      return;
    } else if (t.lanes() == 32) {
      os << "int4";
      return;
    } else if (t.lanes() == 64) {
      os << "int8";
      return;
    } else {
      LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
    }
  }
  case 8: {
    if (t.lanes() == 4) {
      // directly 4 8 bit int in integer.
      // We use int for int8x4 instead of char4 because using char4 is
      // likely to produce extra instructions to pack four int8 elements
      // into 32-bit data.
      os << "int";
      return;
    } else if (t.lanes() == 8) {
      os << "int2";
      return;
    } else if (t.lanes() == 16) {
      os << "int4";
      return;
    } else if (!t.is_uint() && t.is_scalar()) {
      os << "signed char";
      break;
    } else {
      os << "char";
      break;
    }
  }
  case 16: {
    if (t.is_scalar()) {
      os << "short";
    } else if (t.lanes() <= 4) {
      os << "short" << lanes;
    } else if (t.lanes() <= 8) {
      // Emit CUDA code to access int16 vector elements.
      //
      // short4 is stored as int2
      //
      // s4.x is emitted as *(short2*)(&(i2.x)).x
      // s4.y is emitted as *(short2*)(&(i2.x)).y
      // s4.z is emitted as *(short2*)(&(i2.y)).x
      // s4.w is emitted as *(short2*)(&(i2.y)).y
      //
      ICHECK_EQ(t.lanes() % 2, 0)
          << "only support even lane for shorT type with lanes > 4";
      os << "int" << t.lanes() / 2;
    } else {
      fail = true;
    }
    if (!fail) {
      return;
    }
    break;
  }
  case 32: {
    if (t.is_scalar()) {
      os << "int";
    } else if (t.lanes() <= 4) {
      os << "int" << t.lanes();
    } else if (t.lanes() <= 8) {
      // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
      //
      // int8 is stored as longlong4
      //
      // i8.v1 is emitted as *(int2*)(&(l4.x)).x
      // i8.v2 is emitted as *(int2*)(&(l4.x)).y
      //
      ICHECK_EQ(lanes % 2, 0)
          << "only support even lane for int32 type with lanes > 4";
      os << "longlong" << lanes / 2;
    } else {
      fail = true;
      break;
    }
    if (!fail) {
      return;
    }
    break;
  }
  case 64: {
    if (t.is_scalar()) {
      os << "int64_t";
    } else if (t.lanes() == 2) {
      os << "longlong2";
    } else if (t.lanes() == 3) {
      os << "longlong3";
    } else if (t.lanes() == 4) {
      os << "longlong4";
    }
    return;
  }
  default:
    fail = true;
    break;
  }
  if (!fail && lanes == 1) {
    return;
...
...
@@ -379,8 +397,9 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
  LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}

void CodeGenTileLangHIP::PrintVecBinaryOp(const std::string &op, DataType t,
                                          PrimExpr lhs, PrimExpr rhs,
                                          std::ostream &os) { // NOLINT(*)
  // Declare the result.
  std::string sret = name_supply_->FreshName("_");
  this->PrintIndent();
...
...
@@ -414,15 +433,18 @@ void CodeGenTileLangHIP::PrintVecBinaryOp(const std::string& op, DataType t, Pri
  os << sret;
}

void CodeGenTileLangHIP::PrintVecElemLoad(const std::string &vec, DataType t,
                                          int i,
                                          std::ostream &os) { // NOLINT(*)
  if (t.is_scalar()) {
    os << vec;
    return;
  }

  static const char access[] = {'x', 'y', 'z', 'w'};
  ICHECK(i >= 0 &&
         i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    std::string type_name = t.is_int() ? "char" : "unsigned char";
    if (t.lanes() == 2 || t.lanes() == 3) {
...
...
@@ -432,9 +454,11 @@ void CodeGenTileLangHIP::PrintVecElemLoad(const std::string& vec, DataType t, in
      os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
    }
  } else if (t.is_float16()) {
    os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
       << access[i % 2];
  } else if (t.is_bfloat16()) {
    os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
       << access[i % 2];
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
...
...
@@ -453,20 +477,24 @@ void CodeGenTileLangHIP::PrintVecElemLoad(const std::string& vec, DataType t, in
    }
  }
  ICHECK(!type_name.empty());
  os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
     << ")))->" << access[i % 2];
  } else {
    os << vec << "." << access[i];
  }
}

void CodeGenTileLangHIP::PrintVecElemStore(const std::string &vec, DataType t,
                                           int i, const std::string &value) {
  this->PrintIndent();
  static const char access[] = {'x', 'y', 'z', 'w'};
  ICHECK(i >= 0 &&
         i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (t.lanes() == 2 || t.lanes() == 3) {
      stream << vec << '.' << access[i % t.lanes()] << "="
             << "(" << value << ");\n";
    } else {
      std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
      stream << ac << "=";
...
...
@@ -477,11 +505,11 @@ void CodeGenTileLangHIP::PrintVecElemStore(const std::string& vec, DataType t, i
      stream << "(" << value << " << " << i % 4 * 8 << ");\n";
    }
  } else if (t.is_float16()) {
    stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
           << access[i % 2] << " = " << value << ";\n";
  } else if (t.is_bfloat16()) {
    stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
           << access[i % 2] << " = " << value << ";\n";
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
...
...
@@ -500,15 +528,15 @@ void CodeGenTileLangHIP::PrintVecElemStore(const std::string& vec, DataType t, i
    }
  }
  ICHECK(!type_name.empty());
  stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
         << ")))->" << access[i % 2] << " = " << value << ";\n";
  } else {
    stream << vec << "." << access[i] << " = " << value << ";\n";
  }
}

void CodeGenTileLangHIP::PrintStorageSync(const CallNode *op) {
  const std::string &sync = op->args[0].as<StringImmNode>()->value;
  if (sync == "warp") {
    // DO nothing.
  } else if (sync == "shared" || sync == "shared.dyn") {
...
...
@@ -517,9 +545,11 @@ void CodeGenTileLangHIP::PrintStorageSync(const CallNode* op) {
  }
}

void CodeGenTileLangHIP::PrintStorageScope(const std::string &scope,
                                           std::ostream &os) { // NOLINT(*)
  ICHECK_NE(scope, "global")
      << "Cannot allocate global memory when targeting CUDA. You must pass "
         "all global arrays as input instead";
  if (scope == "shared") {
    os << "__shared__ ";
  } else if (scope == "shared.dyn") {
...
...
@@ -527,13 +557,16 @@ void CodeGenTileLangHIP::PrintStorageScope(const std::string& scope, std::ostrea
  }
}

std::string CodeGenTileLangHIP::CastFromTo(std::string value, DataType from,
                                           DataType target) {
  if (from == target)
    return value;
  std::ostringstream os;
  os << "((";
  this->PrintType(target, os);
  os << ")";
  if (from.is_float16() && (target.is_int() || target.is_uint()) &&
      target.bits() == 8) {
    os << "(";
    if (target.is_uint()) {
      os << "u";
...
...
@@ -544,13 +577,14 @@ std::string CodeGenTileLangHIP::CastFromTo(std::string value, DataType from, Dat
  return os.str();
}

void CodeGenTileLangHIP::VisitExpr_(const CastNode *op, std::ostream &os) {
  DataType from_ty = op->value.dtype();
  DataType target_ty = op->dtype;
  ICHECK_EQ(target_ty.lanes(), from_ty.lanes());

  // Emit simple C-style type conversion.
  if (from_ty.is_scalar())
    return CodeGenC::VisitExpr_(op, os);

  // We could emit make_float4 like calls, but the emitted code looks
  // too compact to read. Emit this as vectorized unary ops.
...
...
@@ -573,8 +607,10 @@ void CodeGenTileLangHIP::VisitExpr_(const CastNode* op, std::ostream& os) {
  os << sret;
}

void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol,
                                         const Array<PrimExpr> &args,
                                         bool skip_first_arg,
                                         std::ostream &os) { // NOLINT(*)
  DataType ret_dtype = GetRuntimeDataType(ret_type);
  if (ret_dtype.is_vector()) {
    //
...
...
@@ -614,7 +650,8 @@ void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol, co
      std::ostringstream scall;
      scall << global_symbol << "(";
      for (size_t j = 0; j < sargs.size(); ++j) {
        if (j > 0)
          scall << ", ";
        PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall);
      }
      scall << ")";
...
...
@@ -623,13 +660,16 @@ void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol, co
    }
    os << sret;
  } else {
    CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg,
                              os);
  }
}

// Print a reference expression to a buffer.
std::string CodeGenTileLangHIP::GetBufferRef(DataType t,
                                             const BufferNode *buffer,
                                             PrimExpr index) {
  const VarNode *buffer_var = buffer->data.get();
  std::ostringstream os;
  std::string vid = GetVarID(buffer_var);
  std::string scope;
...
...
@@ -685,12 +725,13 @@ std::string CodeGenTileLangHIP::GetBufferRef(DataType t, const BufferNode* buffe
  return os.str();
}

void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
  auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) {
    this->PrintIndent();
    this->stream << name << "(";
    for (size_t i = offset; i < op->args.size(); i++) {
      if (i > offset)
        this->stream << ", ";
      this->stream << this->PrintExpr(op->args[i]);
    }
    this->stream << ");\n";
...
...
@@ -701,16 +742,18 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
    // use size of argument list to indicate whether or not to use predicated
    // cp.async
    if (op->args.size() == 5) {
      this->PrintIndent();
      this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+"
                   << dst_offset << ", " << src << "+" << src_offset << ");\n";
    } else {
      std::string condition = this->PrintExpr(op->args[5]);
      this->PrintIndent();
      this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst
                   << "+" << dst_offset << ", " << src << "+" << src_offset
                   << ", " << condition << ");\n";
    }
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
    print_extern_call_stmt("tl::cp_async_commit");
...
...
@@ -722,7 +765,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
    this->PrintIndent();
    int barrier_count = Downcast<IntImm>(op->args[0])->value;
    std::string barrier_name = "_mbarrier";
    this->stream << "__shared__ uint64_t " << barrier_name << "["
                 << barrier_count << "];\n";
  } else if (op->op.same_as(tl::GetMBarrierOp())) {
    std::string barrier_name = "_mbarrier";
    std::string barrier_id = this->PrintExpr(op->args[0]);
...
...
@@ -751,13 +795,15 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
    if (trans == 1)
      func_name += "_trans";
    print_extern_call_stmt(func_name, 2);
  } else if (op->op.same_as(tl::STMatrixOp())) {
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
    if (trans == 1)
      func_name += "_trans";
    print_extern_call_stmt(func_name, 2);
  } else if (op->op.same_as(tl::FenceProxyAsyncOp())) {
    print_extern_call_stmt("tl::fence_proxy_async");
...
...
@@ -765,15 +811,16 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
    this->PrintIndent();
    int nreg = Downcast<IntImm>(op->args[0])->value;
    int is_inc = Downcast<IntImm>(op->args[1])->value;
    std::string func_name =
        is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
    this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
  } else if (op->op.same_as(tl::WaitWgmma())) {
    this->PrintIndent();
    int num_mma = Downcast<IntImm>(op->args[0])->value;
    this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
  } else if (op->op.same_as(tl::PackB16Op())) {
    os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
       << this->PrintExpr(op->args[1]) << ")";
  } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 6U);
...
...
@@ -807,7 +854,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[6], os);
    if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
      os << ", nvcuda::wmma::mem_" << str->value;
    } else {
      LOG(FATAL) << "Invalid parameters";
...
...
@@ -833,7 +880,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(builtin::tvm_mfma())) {
    // arg 0: prefix: {otype}_16x16x16{itype}
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
...
...
@@ -847,7 +894,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
    // arg 10: C accumulator
    // arg 11: C accumulator index
    ICHECK(op->args.size() == 12U)
        << "Invalid number of arguments for tvm_mfma";
    std::string prefix = Downcast<StringImm>(op->args[0])->value;
    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
...
...
@@ -860,7 +908,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
    std::string b_bias = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_bias = this->PrintExpr(op->args[11]);
    ICHECK(A_layout == "row" || B_layout == "row")
        << "Matrix core only support row major";
    // map for dtype -> float32x4 -> float4
    std::unordered_map<std::string, std::string> dtype_map = {
        {"int8", "char"},
...
...
@@ -873,8 +922,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
        {"float16x4", "float16x4"},
        {"bfloat16x4", "bfloat16x4"},
        {"float32x4", "float32x4"},
        {"float32x16", "float32x16"}};
    std::string call_mfma_code = R"({
    *((({C_dytpe}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dytpe}*){a_ref}) + {a_bias}),
                  *((({B_dytpe}*){b_ref}) + {b_bias}),
...
...
@@ -893,15 +941,16 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
    replacer.register_rule("{c_ref}", c_ref);
    replacer.register_rule("{c_bias}", c_bias);
    os << replacer.rewrite(call_mfma_code);
  } else {
    CodeGenC::VisitExpr_(op, os);
  }
}

void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode *op) {
  if (op->attr_key == tir::attr::async_commit_queue_scope) {
    const IntImmNode *queue_id = op->value.as<IntImmNode>();
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
    this->VisitStmt(op->body);
    auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
    this->VisitExpr(commit_group, this->stream);
...
...
@@ -909,9 +958,11 @@ void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode* op) {
  } else if (op->attr_key == tir::attr::async_wait_queue_scope) {
    auto wait_attrs = GetAsyncWaitAttributes(op);
    auto queue_id = wait_attrs.first.as<IntImmNode>();
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
    auto wait_cnt = wait_attrs.second;
    auto wait_group =
        Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
    this->VisitExpr(wait_group, this->stream);
    auto inner = op->body.as<AttrStmtNode>();
    ICHECK(inner);
...
...
@@ -919,7 +970,7 @@ void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode* op) {
    return;
  } else if (op->attr_key == "threadblock_swizzle_pattern") {
    this->PrintIndent();
    const StringImmNode *pattern = op->value.as<StringImmNode>();
    ICHECK(pattern);
    this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
    this->VisitStmt(op->body);
...
...
@@ -928,7 +979,7 @@ void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode* op) {
  CodeGenC::VisitStmt_(op);
}

void CodeGenTileLangHIP::VisitStmt_(const AllocateNode *op) {
  ICHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());
...
...
@@ -941,7 +992,8 @@ void CodeGenTileLangHIP::VisitStmt_(const AllocateNode* op) {
    stream << ' ' << vid << "[];\n";
  } else {
    size_t constant_size = op->ConstantAllocationSize();
    ICHECK_GT(constant_size, 0)
        << "Can only handle constant size stack allocation for now";
    if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
         op->dtype == DataType::Int(1)) &&
...
...
@@ -955,7 +1007,7 @@ void CodeGenTileLangHIP::VisitStmt_(const AllocateNode* op) {
  this->PrintStmt(op->body);
}

void CodeGenTileLangHIP::VisitExpr_(const RampNode *op, std::ostream &os) {
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
  CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
  os << "(make_";
...
...
@@ -964,16 +1016,19 @@ void CodeGenTileLangHIP::VisitExpr_(const RampNode* op, std::ostream& os) {
  for (int i = 0; i < lanes; i++) {
    os << "(" << PrintExpr(op->base) << ")"
       << "+(" << PrintExpr(op->stride) << "*" << i << ")";
    if (i != lanes - 1)
      os << ", ";
  }
  os << "))";
}

void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode *op,
                                    std::ostream &os) { // NOLINT(*)
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 &&
      lanes == 4) {
    // make_int8x4
    const int64_t *p = as_const_int(op->value);
    ICHECK(p);
    int64_t v = *p & 0xFF;
    v = (v << 24) | (v << 16) | (v << 8) | v;
...
...
@@ -991,7 +1046,8 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < lanes / 2; ++i) {
      if (i != 0)
        os << ", ";
      os << "__pack_half2(" << v << ", " << v << ")";
    }
    os << ')';
...
...
@@ -1004,18 +1060,21 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < lanes / 2; ++i) {
      if (i != 0)
        os << ", ";
      os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  if (op->dtype.is_float() && op->dtype.bits() == 32 &&
      op->dtype.lanes() == 8) {
    std::string v = PrintExpr(op->value);
    os << "make_ulonglong4(";
    for (int i = 0; i < 4; ++i) {
      if (i != 0)
        os << ", ";
      os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")";
    }
    os << ')';
...
...
@@ -1024,7 +1083,7 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
    bool fail = false;
    const int64_t *p = as_const_int(op->value);
    ICHECK(p);
    int64_t v = *p & 0xF;
...
...
@@ -1036,7 +1095,8 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
        os << "(int16_t)" << v;
      }
    } else {
      v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) |
          (v << 8) | (v << 4) | v;
      if (lanes == 8) {
        if (op->dtype.is_uint()) {
          os << "(uint)" << v;
...
@@ -1048,7 +1108,8 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
PrintType
(
op
->
dtype
,
os
);
os
<<
'('
;
for
(
int
i
=
0
;
i
<
lanes
/
8
;
++
i
)
{
if
(
i
!=
0
)
os
<<
", "
;
if
(
i
!=
0
)
os
<<
", "
;
if
(
op
->
dtype
.
is_uint
())
{
os
<<
"(uint)"
<<
v
;
}
else
{
...
...
@@ -1071,13 +1132,15 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
  PrintType(op->dtype, os);
  os << '(';
  for (int i = 0; i < lanes; ++i) {
    if (i != 0)
      os << ", ";
    os << v;
  }
  os << ')';
}

inline void PrintConst(const FloatImmNode *op, std::ostream &os,
                       CodeGenTileLangHIP *p) { // NOLINT(*)
  // Type code is kBFloat
  if (op->dtype.is_bfloat16()) {
    os << "bfloat16_t";
...
...
@@ -1086,46 +1149,50 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLang
  }
  // Type code is kFloat
  switch (op->dtype.bits()) {
  case 64:
  case 32: {
    std::ostringstream temp;
    if (std::isinf(op->value)) {
      if (op->value < 0) {
        temp << "-";
      }
      temp << ((op->dtype.bits() == 32) ? "HIPRT_INF_F" : "HIPRT_INF");
    } else if (std::isnan(op->value)) {
      temp << ((op->dtype.bits() == 32) ? "HIPRT_NAN_F" : "HIPRT_NAN");
    } else {
      temp << std::scientific << op->value;
      if (op->dtype.bits() == 32)
        temp << 'f';
    }
    p->MarkConst(temp.str());
    os << temp.str();
    break;
  }
  case 16: {
    os << "half_t" << '(';
    FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
    PrintConst(const_f32.get(), os, p);
    os << ')';
    break;
  }
  default:
    LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
  }
}

void CodeGenTileLangHIP::VisitExpr_(const FloatImmNode *op,
                                    std::ostream &os) { // NOLINT(*)
  PrintConst(op, os, this);
}

void CodeGenTileLangHIP::HandleVolatileLoads(const std::string &value,
                                             const BufferLoadNode *op,
                                             std::ostream &os) {
  // Cast away volatile qualifier for fp16 types. That is, only loads and
  // stores are volatile. The loaded objects are not marked as volatile.
  //
  if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) &&
      IsVolatile(op->buffer->data.get())) {
    os << "(";
    PrintType(op->dtype, os);
    os << ")(" << value << ")";
...
...
@@ -1134,15 +1201,17 @@ void CodeGenTileLangHIP::HandleVolatileLoads(const std::string& value, const Buf
  }
}

void CodeGenTileLangHIP::PrintVecElemLoadExpr(DataType t, int i,
                                              const std::string &value,
                                              std::ostream &os) {
  ICHECK_GT(t.lanes(), 1);
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (!(t.lanes() == 2 || t.lanes() == 3)) {
      if (i != 0) {
        os << "|";
      }
      os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8
         << "))";
      return;
    }
  }
...
...
@@ -1199,7 +1268,7 @@ void CodeGenTileLangHIP::PrintVecElemLoadExpr(DataType t, int i, const std::stri
  return;
}

void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) {
  // clear previous generated state.
  this->InitFuncState(f);
  // reserve keywords
...
...
@@ -1218,10 +1287,11 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc& f) {
  for (size_t i = 0; i < f->params.size(); ++i) {
    tir::Var v = f->params[i];
    std::string vid = AllocVarID(v.get());
    if (i != 0)
      stream << ", ";
    if (v.dtype().is_handle()) {
      // work around for grid constant parameters.
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (ptr->storage_scope == "grid_constant") {
          stream << "__grid_constant__ const ";
          CodeGenC::PrintType(ptr->element_type, stream);
...
...
@@ -1236,8 +1306,8 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc& f) {
      }
      CodeGenC::PrintType(GetType(v), stream);
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
          RegisterHandleType(v.get(), prim->dtype);
        }
      }
...
...
@@ -1259,5 +1329,5 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc& f) {
  this->stream << "}\n\n";
}

} // namespace codegen
} // namespace tvm
src/target/codegen_hip.h
View file @
549416f7
...
...
@@ -21,50 +21,58 @@ namespace tvm {
namespace codegen {

class CodeGenTileLangHIP final : public CodeGenC {
public:
  CodeGenTileLangHIP();
  std::string Finish();
  // override behavior
  void PrintFuncPrefix(std::ostream &os) final;
  void PrintExtraAttrs(const PrimFunc &f, std::ostream &os) final;
  void VisitStmt_(const ForNode *op) final;
  void PrintStorageSync(const CallNode *op) final;
  void PrintStorageScope(const std::string &scope,
                         std::ostream &os) final; // NOLINT(*)
  void PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs,
                        PrimExpr rhs, std::ostream &os) final; // NOLINT(*)
  void PrintType(DataType t, std::ostream &os) final;          // NOLINT(*)
  void PrintVecElemLoad(const std::string &vec, DataType t, int i,
                        std::ostream &os) final; // NOLINT(*)
  void PrintVecElemStore(const std::string &vec, DataType t, int i,
                         const std::string &value) final;
  void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
  void PrintVecElemLoadExpr(DataType t, int i, const std::string &value,
                            std::ostream &os) final;
  std::string CastFromTo(std::string value, DataType from,
                         DataType target) final;
  // overload visitor
  void VisitExpr_(const RampNode *op, std::ostream &os) final;      // NOLINT(*)
  void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
  void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
  void VisitExpr_(const CallNode *op, std::ostream &os) final;
  void VisitExpr_(const CastNode *op, std::ostream &os) final;
  void VisitStmt_(const AllocateNode *op) final;
  void VisitStmt_(const AttrStmtNode *op) final;

  // Override this as a work around for __grid_constant__ parameter
  void AddFunction(const PrimFunc &f);

protected:
  virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
                                   PrimExpr index) final;
  void PrintCallExtern(Type ret_type, String global_symbol,
                       const Array<PrimExpr> &args, bool skip_first_arg,
                       std::ostream &os) final; // NOLINT(*)

private:
  // Handle volatile loads
  void HandleVolatileLoads(const std::string &value, const BufferLoadNode *op,
                           std::ostream &os) final;

  // Whether scope such as "__shared__" or "__constant__" is part of type.
  bool IsScopePartOfType() const final { return false; }

  friend void PrintConst(const FloatImmNode *op, std::ostream &os,
                         CodeGenTileLangHIP *p);

  // whether need math_constants.h
  bool need_math_constants_h_{false};
...
...
@@ -83,7 +91,7 @@ class CodeGenTileLangHIP final : public CodeGenC {
  const int barrier_alignment_bytes_ = 16;
};

} // namespace codegen
} // namespace tvm

#endif // TVM_TL_TARGET_CODEGEN_HIP_H_
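
A minimal usage sketch (not part of this diff) of how the class declared above is driven, mirroring the BuildTileLangHIP registration that appears further down in src/target/rt_mod_hip.cc; `mod` is assumed to be an already-lowered IRModule whose PrimFuncs use the device-kernel calling convention, and EmitHIPSource is a hypothetical helper:

// Hypothetical driver, assumes an already-lowered IRModule `mod`.
#include "codegen_hip.h"

std::string EmitHIPSource(const tvm::IRModule &mod) {
  using namespace tvm;
  using namespace tvm::codegen;
  bool output_ssa = false;
  CodeGenTileLangHIP cg;
  cg.Init(output_ssa);
  for (auto kv : mod->functions) {
    auto f = Downcast<tir::PrimFunc>(kv.second);
    cg.AddFunction(f); // emits the kernel signature, launch bounds, and body
  }
  return cg.Finish();  // appends anything collected during codegen (headers, decls)
}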
src/target/cuda.h
View file @
549416f7
This source diff could not be displayed because it is too large. You can view the blob instead.
src/target/rt_mod_cuda.cc
View file @
549416f7
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#include "codegen_cuda.h"
#include "runtime/cuda/cuda_module.h"
namespace tvm {
namespace codegen {

static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
  std::unordered_map<std::string, runtime::FunctionInfo> fmap;
  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
        << "Can only lower IR Module with PrimFuncs";
    auto f = Downcast<tir::PrimFunc>(kv.second);
    runtime::FunctionInfo info;
...
...
@@ -26,7 +28,7 @@ static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
      info.arg_types.push_back(f->params[i].dtype());
    }
    if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
      for (const auto &tag : opt.value()) {
        info.launch_param_tags.push_back(tag);
      }
    }
...
...
@@ -43,7 +45,8 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
  cg.Init(output_ssa);

  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangCUDA: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
...
...
@@ -51,14 +54,15 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
  }

  std::string code = cg.Finish();
  if (const auto *f = Registry::Get("tvm_callback_cuda_postproc")) {
    code = (*f)(code, target).operator std::string();
  }
  std::string fmt = "ptx";
  std::string ptx;
  if (const auto *f = Registry::Get("tvm_callback_cuda_compile")) {
    ptx = (*f)(code, target).operator std::string();
    if (ptx[0] != '/')
      fmt = "cubin";
  } else {
    ICHECK(0);
  }
...
...
@@ -72,7 +76,8 @@ String BuildTLDebug(IRModule mod, Target target) {
  cg.Init(output_ssa);

  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangCUDA: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
...
...
@@ -80,14 +85,16 @@ String BuildTLDebug(IRModule mod, Target target) {
  }

  std::string code = cg.Finish();
  if (const auto *f = Registry::Get("tvm_callback_cuda_postproc")) {
    code = (*f)(code, target).operator std::string();
  }
  return String(code);
}

TVM_REGISTER_GLOBAL("target.build.tilelang_cuda")
    .set_body_typed(BuildTileLangCUDA);
TVM_REGISTER_GLOBAL("target.build.tl_debug_codegen")
    .set_body_typed(BuildTLDebug);

} // namespace codegen
} // namespace tvm
src/target/rt_mod_hip.cc
View file @
549416f7
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#if defined(__linux__)
#include <sys/stat.h>
#endif
...
...
@@ -8,28 +8,28 @@
#include <hip/hip_runtime.h>
#include <hip/hiprtc.h>

#include "codegen_hip.h"
#include "runtime/rocm/rocm_module.h"
namespace tvm {
namespace codegen {

#define HIPRTC_CALL(x)                                                        \
  {                                                                           \
    hiprtcResult result = x;                                                  \
    if (result != HIPRTC_SUCCESS) {                                           \
      LOG(FATAL) << "HiprtcError: " #x " failed with error: "                 \
                 << hiprtcGetErrorString(result);                             \
    }                                                                         \
  }

static std::string FindHIPIncludePath() {
...
...
@@ -39,7 +39,7 @@ static std::string FindHIPIncludePath() {
  const std::string delimiter = "/";
#endif
  std::string hip_include_path;
  const char *hip_path_env = std::getenv("HIP_PATH");
  if (hip_path_env != nullptr) {
    hip_include_path += hip_path_env;
    hip_include_path += delimiter + "include";
...
...
@@ -58,19 +58,24 @@ static std::string FindHIPIncludePath() {
  }
#endif
  LOG(FATAL) << "Cannot find HIP include path."
             << "HIP_PATH is not set or ROCm is not installed in the default "
                "installation path."
             << "In other than linux, it is necessary to set HIP_PATH.";
  return hip_include_path;
}

static std::string HIPRTCCompile(const std::string &code,
                                 bool include_path = false) {
  std::vector<std::string> compile_params;
  std::vector<const char *> param_cstrings{};
  hiprtcProgram prog;
  std::string cc =
      "gfx900"; // Default target architecture (can be changed as needed)
  int major, minor;
  hipError_t e1 = hipDeviceGetAttribute(
      &major, hipDeviceAttributeComputeCapabilityMajor, 0);
  hipError_t e2 = hipDeviceGetAttribute(
      &minor, hipDeviceAttributeComputeCapabilityMinor, 0);

  if (e1 == hipSuccess && e2 == hipSuccess) {
    cc = "gfx" + std::to_string(major * 100 + minor * 10);
...
...
@@ -86,10 +91,11 @@ static std::string HIPRTCCompile(const std::string& code, bool include_path = fa
    compile_params.push_back(include_option);
  }

  for (const auto &string : compile_params) {
    param_cstrings.push_back(string.c_str());
  }
  HIPRTC_CALL(
      hiprtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
  hiprtcResult compile_res =
      hiprtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());
...
...
@@ -110,11 +116,13 @@ static std::string HIPRTCCompile(const std::string& code, bool include_path = fa
  return code_out;
}

static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
  std::unordered_map<std::string, runtime::FunctionInfo> fmap;
  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
        << "Can only lower IR Module with PrimFuncs";
    auto f = Downcast<tir::PrimFunc>(kv.second);
    runtime::FunctionInfo info;
...
...
@@ -129,7 +137,7 @@ static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
      info.arg_types.push_back(f->params[i].dtype());
    }
    if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
      for (const auto &tag : opt.value()) {
        info.launch_param_tags.push_back(tag);
      }
    }
...
...
@@ -146,7 +154,8 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
  cg.Init(output_ssa);

  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangHIP: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
...
...
@@ -154,21 +163,23 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
  }

  std::string code = cg.Finish();
  if (const auto *f = Registry::Get("tvm_callback_hip_postproc")) {
    code = (*f)(code, target).operator std::string();
  }
  std::string fmt = "ptx";
  std::string ptx;
  if (const auto *f = Registry::Get("tvm_callback_hip_compile")) {
    ptx = (*f)(code, target).operator std::string();
    if (ptx[0] != '/')
      fmt = "hsaco";
  } else {
    ptx = HIPRTCCompile(code, false);
  }
  return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string());
}

TVM_REGISTER_GLOBAL("target.build.tilelang_hip")
    .set_body_typed(BuildTileLangHIP);

} // namespace codegen
} // namespace tvm
src/target/utils.cc
View file @
549416f7
...
...
@@ -11,13 +11,17 @@
namespace tvm {
namespace tl {

bool TargetIsCuda(Target target) {
  return target->GetTargetDeviceType() == kDLCUDA;
}
bool TargetIsRocm(Target target) {
  return target->GetTargetDeviceType() == kDLROCM;
}

int GetArchInt(Target target) {
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  const char *arch_str = s.value().c_str();
  ICHECK_EQ(arch_str[0], 's');
  ICHECK_EQ(arch_str[1], 'm');
  ICHECK_EQ(arch_str[2], '_');
...
...
@@ -25,31 +29,36 @@ int GetArchInt(Target target) {
}

bool TargetIsVolta(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 70 && arch < 75;
}

bool TargetIsTuring(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 75 && arch < 80;
}

bool TargetIsAmpere(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 80 && arch < 90;
}

bool TargetIsHopper(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 90;
}

bool TargetIsCDNA(Target target) {
  if (!TargetIsRocm(target))
    return false;
  if (target->attrs.count("mcpu")) {
    std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
    // if mcpu start with "gfx9", it is CDNA
...
...
@@ -78,16 +87,18 @@ bool TargetHasAsyncCopy(Target target) {
  return false;
}

bool TargetHasLdmatrix(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 75;
}

bool TargetHasStmatrix(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 90;
}

} // namespace tl
} // namespace tvm
src/target/utils.h
View file @
549416f7
...
...
@@ -23,12 +23,12 @@ bool TargetIsTuring(Target target);
bool TargetIsAmpere(Target target);
bool TargetIsHopper(Target target);
bool TargetIsCDNA(Target target);

bool TargetHasAsyncCopy(Target target);
bool TargetHasLdmatrix(Target target);
bool TargetHasStmatrix(Target target);

} // namespace tl
} // namespace tvm

#endif // TVM_TL_TARGET_UTILS_H_
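
An illustrative sketch (not from this diff) of how these target predicates are typically consumed when choosing a codegen path; the arch thresholds follow utils.cc above (ldmatrix needs sm_75+, stmatrix sm_90+), while `SelectCopyPath` and the enum are hypothetical names used only for the example:

// Hypothetical caller; only the predicates come from utils.h.
#include "utils.h"

enum class CopyPath { kLdmatrix, kAsyncCopy, kPlain };

CopyPath SelectCopyPath(tvm::Target target) {
  using namespace tvm::tl;
  if (TargetIsCuda(target) && TargetHasLdmatrix(target))
    return CopyPath::kLdmatrix;  // sm_75 and newer
  if (TargetHasAsyncCopy(target))
    return CopyPath::kAsyncCopy; // cp.async-capable targets
  return CopyPath::kPlain;       // fallback (e.g. CDNA goes through MFMA paths)
}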
src/tl_templates/cuda/common.h
View file @
549416f7
...
...
@@ -25,56 +25,57 @@ using cutlass::tfloat32_t;
// Pack two half values.
TL_DEVICE
unsigned
__pack_half2
(
const
half
x
,
const
half
y
)
{
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
return
(
v1
<<
16
)
|
v0
;
}
// Pack two half_t values.
TL_DEVICE
unsigned
__pack_half2
(
const
half_t
x
,
const
half_t
y
)
{
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
return
(
v1
<<
16
)
|
v0
;
}
// Pack two bfloat16_t values.
TL_DEVICE
unsigned
__pack_half2
(
const
bfloat16_t
x
,
const
bfloat16_t
y
)
{
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
return
(
v1
<<
16
)
|
v0
;
}
/// Helper to cast SMEM pointer to unsigned
TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) {
  return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
}

// AtomicAdd Functions for FP16
TL_DEVICE void atomicAdd(half_t *address, half_t val) {
  // Use atomicCAS with built-in cuda_fp16 support
  atomicAdd(reinterpret_cast<half *>(address), static_cast<half>(val));
}

// AtomicAdd Functions for FP16
TL_DEVICE void atomicAdd(half_t *address, half_t *val) {
  atomicAdd(reinterpret_cast<half *>(address), static_cast<half>(*val));
}

// AtomicAdd Functions for FP16
TL_DEVICE void atomicAddx2(half_t *address, half_t *val) {
  atomicAdd(reinterpret_cast<half2 *>(address),
            static_cast<half2>(*reinterpret_cast<half2 *>(val)));
}

TL_DEVICE void atomicAdd(half_t *address, float val) {
  // Use atomicCAS with built-in cuda_fp16 support
  atomicAdd(reinterpret_cast<half *>(address), __float2half(val));
}
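smem_ptr_to_uint converts a generic pointer into the 32-bit shared-memory address form expected by inline PTX (the cp.async helpers in copy.h below rely on it), and the atomicAdd overloads forward half_t operands to the native half/half2 atomics. A minimal sketch of the vectorized variant, assuming dst points to two adjacent, 4-byte-aligned half_t values; add_pair is a hypothetical name, not part of this commit:
// Illustrative only: adds src[0] to dst[0] and src[1] to dst[1] in one half2 atomic.
TL_DEVICE void add_pair(half_t *dst, half_t *src) {
  atomicAddx2(dst, src);
}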
// DP4A
template <typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) {
  const int a_int = *((int *)a);
  const int b_int = *((int *)b);
  const int c_int = *((int *)c);
  *c = __dp4a(a_int, b_int, c_int);
}
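DP4A reinterprets four packed 8-bit lanes of a and b as one 32-bit word each and accumulates their dot product into *c through __dp4a. A minimal sketch, assuming int8 inputs and an int32 accumulator; dp4a_demo is a hypothetical name, not part of this commit:
// Illustrative only: acc += a4[0]*b4[0] + a4[1]*b4[1] + a4[2]*b4[2] + a4[3]*b4[3]
TL_DEVICE int dp4a_demo(int8_t *a4, int8_t *b4) {
  int acc = 0;        // running accumulator, read and updated by DP4A
  DP4A(a4, b4, &acc);
  return acc;
}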
src/tl_templates/cuda/copy.h
View file @
549416f7
...
...
@@ -10,10 +10,11 @@
namespace tl {

TL_DEVICE void cp_async_commit() {
  asm volatile("cp.async.commit_group;\n" ::);
}

template <int N> TL_DEVICE void cp_async_wait() {
  if constexpr (N == 0) {
    asm volatile("cp.async.wait_all;\n" ::);
  } else {
...
...
@@ -22,7 +23,7 @@ TL_DEVICE void cp_async_wait() {
}

template <int N>
TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
  static_assert(N == 16 || N == 8 || N == 4);
  unsigned int addr = smem_ptr_to_uint(smem_addr);
  if constexpr (N == 16) {
...
...
@@ -33,7 +34,7 @@ TL_DEVICE void cp_async_gs(void const* const smem_addr, void* global_ptr) {
        "cp.async.cg.shared.global [%0], [%1], %2;"
#endif
        ::"r"(addr), "l"((void *)(global_ptr)), "n"(N));
  } else {
    __asm__ __volatile__(
#if TL_ENABLE_L2_PREFETCH
...
...
@@ -42,12 +43,13 @@ TL_DEVICE void cp_async_gs(void const* const smem_addr, void* global_ptr) {
        "cp.async.ca.shared.global [%0], [%1], %2;"
#endif
        ::"r"(addr), "l"((void *)(global_ptr)), "n"(N));
  }
}

template <int N>
TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
                                       void *global_ptr, bool cond) {
  static_assert(N == 16 || N == 8 || N == 4);
  int bytes = cond ? N : 0;
  unsigned int addr = smem_ptr_to_uint(smem_addr);
...
...
@@ -59,7 +61,7 @@ TL_DEVICE void cp_async_gs_conditional(void const* const smem_addr, void* global
        "cp.async.cg.shared.global [%0], [%1], %2, %3;"
#endif
        ::"r"(addr), "l"((void *)(global_ptr)), "n"(N), "r"(bytes));
  } else {
    __asm__ __volatile__(
#if TL_ENABLE_L2_PREFETCH
...
...
@@ -68,8 +70,8 @@ TL_DEVICE void cp_async_gs_conditional(void const* const smem_addr, void* global
        "cp.async.ca.shared.global [%0], [%1], %2, %3;"
#endif
        ::"r"(addr), "l"((void *)(global_ptr)), "n"(N), "r"(bytes));
  }
}

} // namespace tl
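A typical use of these helpers issues one or more cp_async_gs copies, commits the group, then waits before the staged data is consumed. A minimal sketch, assuming 16-byte copies and hypothetical smem/gmem pointers supplied by the caller; stage_tile is not part of this commit:
// Illustrative only: single-stage global->shared async copy.
TL_DEVICE void stage_tile(void *smem, void *gmem) {
  tl::cp_async_gs<16>(smem, gmem); // issue a 16-byte cp.async copy
  tl::cp_async_commit();           // close the current commit group
  tl::cp_async_wait<0>();          // wait for all outstanding groups
  __syncthreads();                 // make staged data visible block-wide
}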