OpenDAS / tilelang · Commits

Commit bbbf4207 (unverified), authored Nov 14, 2025 by guchaoyang; committed by GitHub, Nov 14, 2025

Merge branch 'main' into dcu

Parents: 8f4628e0, 5eb30a4f
Changes: 286 · Showing 20 changed files with 832 additions and 1057 deletions (+832 / -1057)
src/op/reduce.h                   +29    -40
src/op/region.cc                   +4     -2
src/op/region.h                    +4    -17
src/op/tcgen5_meta.h             +163     -0
src/runtime/runtime.cc             +4     -4
src/support/ffi_aliases.h         +16     -0
src/target/codegen_cpp.cc          +4     -4
src/target/codegen_cpp.h           +4     -4
src/target/codegen_cuda.cc       +545    -75
src/target/codegen_cuda.h         +16     -4
src/target/codegen_hip.cc          +4     -4
src/target/codegen_hip.h           +2     -2
src/target/codegen_webgpu.cc       +0   -786
src/target/codegen_webgpu.h        +0   -104
src/target/intrin_rule_cuda.cc     +1     -0
src/target/intrin_rule_hip.cc      +1     -0
src/target/ptx.cc                 +17     -2
src/target/ptx.h                   +5     -0
src/target/rt_mod_cpp.cc           +6     -3
src/target/rt_mod_cuda.cc          +7     -6
src/op/reduce.h
...
@@ -30,23 +30,13 @@ enum class ReduceTypeEnum : uint8_t {
 class ReduceTypeNode : public Object {
 public:
   int type{-1}; ///< Internal type identifier
-  static constexpr const char *_type_key = "tl.ReduceType";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ReduceTypeNode, Object);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceType", ReduceTypeNode, Object);
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
     refl::ObjectDef<ReduceTypeNode>().def_ro("type", &ReduceTypeNode::type);
   }
-  bool SEqualReduce(const ReduceTypeNode *other, SEqualReducer equal) const {
-    return equal(type, other->type);
-  }
-  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(type); }
-  static constexpr bool _type_has_method_sequal_reduce = true;
-  static constexpr bool _type_has_method_shash_reduce = true;
   /// Type checking methods
   bool isSum() const { return type == int(ReduceTypeEnum::kSum); }
   bool isAbsSum() const { return type == int(ReduceTypeEnum::kAbsSum); }
...
@@ -61,9 +51,10 @@ public:
 /// Wrapper class for reduction type with string-based construction
 class ReduceType : public ObjectRef {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(ReduceType, ObjectRef, ReduceTypeNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceType, ObjectRef,
+                                             ReduceTypeNode);
   TVM_DLL ReduceType(std::string type) {
-    auto node = make_object<ReduceTypeNode>();
+    auto node = tvm::ffi::make_object<ReduceTypeNode>();
     if (type == "sum") {
       node->type = int(ReduceTypeEnum::kSum);
     } else if (type == "abssum") {
...
@@ -91,40 +82,27 @@ public:
 class ReduceOpNode : public TileOperatorNode {
 public:
   tir::Buffer src, dst; ///< Source and destination buffers
   // Optional: keep the original regions used to construct this op
   BufferRegion srcRegion_, dstRegion_;
   int dim;         ///< Dimension to reduce along
   ReduceType type; ///< Type of reduction operation
   bool clear;      ///< Whether to clear destination before reduction
-  static constexpr const char *_type_key = "tl.ReduceOp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ReduceOpNode, TileOperatorNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceOp", ReduceOpNode,
+                                    TileOperatorNode);
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
     refl::ObjectDef<ReduceOpNode>()
         .def_ro("src", &ReduceOpNode::src)
         .def_ro("dst", &ReduceOpNode::dst)
         .def_ro("srcRegion", &ReduceOpNode::srcRegion_)
         .def_ro("dstRegion", &ReduceOpNode::dstRegion_)
         .def_ro("dim", &ReduceOpNode::dim)
         .def_ro("type", &ReduceOpNode::type)
         .def_ro("clear", &ReduceOpNode::clear);
   }
-  bool SEqualReduce(const ReduceOpNode *other, SEqualReducer equal) const {
-    return equal(src, other->src) && equal(dst, other->dst) &&
-           equal(dim, other->dim) && equal(type, other->type) &&
-           equal(clear, other->clear);
-  }
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(src);
-    hash_reduce(dst);
-    hash_reduce(dim);
-    hash_reduce(type);
-    hash_reduce(clear);
-  }
-  static constexpr bool _type_has_method_sequal_reduce = true;
-  static constexpr bool _type_has_method_shash_reduce = true;
   /// Lower the operator to TIR statements
   Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
   /// Infer memory layout for buffers
...
@@ -145,7 +123,8 @@ private:
 /// Wrapper class for reduction operations
 class ReduceOp : public TileOperator {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(ReduceOp, TileOperator, ReduceOpNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceOp, TileOperator,
+                                             ReduceOpNode);
   TVM_DLL ReduceOp(Array<PrimExpr> args, BufferMap vmap);
   static const Op &Get();
 };
...
@@ -156,8 +135,17 @@ public:
   tir::Buffer src, dst; ///< Source and destination buffers
   int dim;      ///< Dimension along which to compute cumulative sum
   bool reverse; ///< Whether to compute in reverse order
-  static constexpr const char *_type_key = "tl.CumSumOp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(CumSumOpNode, TileOperatorNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.CumSumOp", CumSumOpNode,
+                                    TileOperatorNode);
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<CumSumOpNode>()
+        .def_ro("src", &CumSumOpNode::src)
+        .def_ro("dst", &CumSumOpNode::dst)
+        .def_ro("dim", &CumSumOpNode::dim)
+        .def_ro("reverse", &CumSumOpNode::reverse);
+  }
   Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
   LayoutMap InferLayout(const LayoutInferArgs &T,
...
@@ -169,7 +157,8 @@ public:
 /// Wrapper class for cumulative sum operations
 class CumSumOp : public TileOperator {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(CumSumOp, TileOperator, CumSumOpNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CumSumOp, TileOperator,
+                                             CumSumOpNode);
   TVM_DLL CumSumOp(Array<PrimExpr> args, BufferMap vmap);
   static const Op &Get();
 };
...
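Taken together, the reduce.h hunks show the migration pattern this merge applies across tilelang: the legacy `_type_key` member plus `TVM_DECLARE_FINAL_OBJECT_INFO`, and the hand-written `SEqualReduce`/`SHashReduce` methods, give way to the new `tvm::ffi` macros plus field reflection. A minimal sketch of the resulting pattern, using a hypothetical node (the `tl.Example` name and its field are made up, not from this commit):

    class ExampleNode : public Object {
    public:
      int value{0};
      // The type key now lives in the macro argument instead of a _type_key member.
      TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Example", ExampleNode, Object);
      static void RegisterReflection() {
        namespace refl = tvm::ffi::reflection;
        // def_ro registers the field for FFI access; structural equality and
        // hashing can then be derived from the reflected fields, which appears
        // to be why the hand-written SEqualReduce/SHashReduce bodies are
        // deleted in the hunks above.
        refl::ObjectDef<ExampleNode>().def_ro("value", &ExampleNode::value);
      }
    };

    class Example : public ObjectRef {
    public:
      // The _NULLABLE variant keeps the default-constructed (null) state that
      // the old TVM_DEFINE_OBJECT_REF_METHODS macro allowed.
      TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Example, ObjectRef, ExampleNode);
    };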
src/op/region.cc
...
@@ -44,7 +44,7 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
     PrimExpr extent = args[2 + i];
     ranges.push_back(Range::FromMinExtent(min, extent));
   }
-  ObjectPtr<RegionOpNode> node = make_object<RegionOpNode>();
+  ObjectPtr<RegionOpNode> node = tvm::ffi::make_object<RegionOpNode>();
   node->buffer_ = load->buffer;
   node->access_mask_ = static_cast<int>(*as_const_int(args[1]));
   node->ranges_ = ranges;
...
@@ -57,7 +57,7 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
  * @return TileOperator A new TileOperator that owns a copied RegionOpNode.
  */
 TileOperator RegionOpNode::Clone() const {
-  auto op = make_object<RegionOpNode>(*this);
+  auto op = tvm::ffi::make_object<RegionOpNode>(*this);
   return RegionOp(op);
 }
...
@@ -118,5 +118,7 @@ TIR_REGISTER_TL_OP(RegionOp, region)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kPure));

+TVM_FFI_STATIC_INIT_BLOCK() { RegionOpNode::RegisterReflection(); }
+
 } // namespace tl
 } // namespace tvm
src/op/region.h
...
@@ -80,8 +80,8 @@ public:
   Array<Range> ranges_;
   int access_mask_;
-  static constexpr const char *_type_key = "tl.RegionOp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(RegionOpNode, TileOperatorNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.RegionOp", RegionOpNode,
+                                    TileOperatorNode);
   Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
   LayoutMap InferLayout(const LayoutInferArgs &T,
...
@@ -101,25 +101,12 @@ public:
         .def_ro("ranges", &RegionOpNode::ranges_)
         .def_ro("access_mask", &RegionOpNode::access_mask_);
   }
-  bool SEqualReduce(const RegionOpNode *other, SEqualReducer equal) const {
-    return equal(buffer_, other->buffer_) && equal(ranges_, other->ranges_) &&
-           equal(access_mask_, other->access_mask_);
-  }
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(buffer_);
-    hash_reduce(ranges_);
-    hash_reduce(access_mask_);
-  }
-  static constexpr bool _type_has_method_sequal_reduce = true;
-  static constexpr bool _type_has_method_shash_reduce = true;
 };

 class RegionOp : public TileOperator {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(RegionOp, TileOperator, RegionOpNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RegionOp, TileOperator,
+                                             RegionOpNode);
   TVM_DLL RegionOp(Array<PrimExpr> args, BufferMap vmap);
   static const Op &Get();
...
src/op/tcgen5_meta.h (new file, mode 100644)
#ifndef TVM_TL_OP_TCGEN5_META_H_
#define TVM_TL_OP_TCGEN5_META_H_

#include <cstdint>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/logging.h>
#include <utility>
#include <vector>

namespace tvm {
namespace tl {

using runtime::DataType;

struct TCGEN5MMAMeta {
  int atom_m, atom_n, atom_k;
};

inline std::pair<bool, TCGEN5MMAMeta> GetTCGEN5MMAMeta(int M, int N, int K,
                                                       DataType ab_dtype,
                                                       DataType c_dtype) {
  // TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL \
  return { false, TCGEN5MMAMeta{0, 0, 0} }
#define SUCCESS(atom_m, atom_n, atom_k) \
  return { \
    true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
  }
  std::vector<int> ws_valid_atom_ns = {256, 128, 64};
  if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
      (c_dtype.is_float() && c_dtype.bits() == 32)) {
    if (K % 16 != 0)
      FAIL;
    if (M % 128 == 0) {
      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
        if (N % atom_n == 0)
          SUCCESS(128, atom_n, 16);
      FAIL;
    } else if (M % 64 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(64, atom_n, 16);
      FAIL;
    } else if (M % 32 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(32, atom_n, 16);
      FAIL;
    } else {
      FAIL;
    }
  } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
             (c_dtype.is_float() && c_dtype.bits() == 32)) {
    if (K % 32 != 0)
      FAIL;
    if (M % 128 == 0) {
      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
        if (N % atom_n == 0)
          SUCCESS(128, atom_n, 32);
      FAIL;
    } else if (M % 64 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(64, atom_n, 32);
      FAIL;
    } else if (M % 32 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(32, atom_n, 32);
      FAIL;
    } else {
      FAIL;
    }
  }
  FAIL;
#undef FAIL
#undef SUCCESS
}

inline uint32_t GetTCGEN5InstrDesc(int atom_m, int atom_n, int atom_k,
                                   DataType ab_dtype, DataType c_dtype,
                                   bool a_is_k_major, bool b_is_k_major,
                                   int scale_in_a, int scale_in_b) {
  ICHECK(atom_m % 16 == 0) << "atom_m must be divisible by 16";
  ICHECK(atom_n % 8 == 0) << "atom_n must be divisible by 8";
  ICHECK(atom_k == 16 || atom_k == 32)
      << "Unsupported atom_k for TCGEN5MMA descriptor: " << atom_k;
  ICHECK(scale_in_a == 1 || scale_in_a == -1)
      << "scale_in_a must be +/-1 for TCGEN5MMA";
  ICHECK(scale_in_b == 1 || scale_in_b == -1)
      << "scale_in_b must be +/-1 for TCGEN5MMA";
  auto encode_dtype = [&](DataType dtype) -> uint32_t {
    if (dtype.is_float16()) {
      return static_cast<uint32_t>(0);
    } else if (dtype.is_bfloat16()) {
      return static_cast<uint32_t>(1);
    } else if (dtype.is_float8_e4m3fn() || dtype.is_float8_e4m3fnuz() ||
               dtype.is_float8_e4m3()) {
      return static_cast<uint32_t>(0);
    } else if (dtype.is_float8_e5m2fnuz() || dtype.is_float8_e5m2()) {
      return static_cast<uint32_t>(1);
    }
    LOG(FATAL) << "Unsupported dtype for TCGEN5MMA descriptor: " << dtype;
    return 0u;
  };
  uint32_t a_format = encode_dtype(ab_dtype);
  uint32_t b_format = a_format;
  uint32_t c_format = 0;
  if (c_dtype.is_float16()) {
    c_format = 0;
  } else if (c_dtype.is_float()) {
    c_format = 1;
  } else if (c_dtype.is_int()) {
    c_format = 2;
  } else {
    LOG(FATAL) << "Unsupported accumulator dtype for TCGEN5MMA descriptor: "
               << c_dtype;
  }
  auto set_bits = [](uint32_t value, int start, int width) -> uint32_t {
    uint32_t mask = (width == 32) ? 0xFFFFFFFFu : ((1u << width) - 1);
    return (value & mask) << start;
  };
  uint32_t desc = 0;
  desc |= set_bits(0, 0, 2); // sparse_id2
  desc |= set_bits(0, 2, 1); // sparse_flag
  desc |= set_bits(0, 3, 1); // saturate
  desc |= set_bits(c_format, 4, 2);
  desc |= set_bits(a_format, 7, 3);
  desc |= set_bits(b_format, 10, 3);
  uint32_t a_neg = (scale_in_a == -1) ? 1u : 0u;
  uint32_t b_neg = (scale_in_b == -1) ? 1u : 0u;
  desc |= set_bits(a_neg, 13, 1);
  desc |= set_bits(b_neg, 14, 1);
  uint32_t a_major = a_is_k_major ? 0u : 1u;
  uint32_t b_major = b_is_k_major ? 0u : 1u;
  desc |= set_bits(a_major, 15, 1);
  desc |= set_bits(b_major, 16, 1);
  uint32_t n_dim = static_cast<uint32_t>(atom_n >> 3);
  uint32_t m_dim = static_cast<uint32_t>(atom_m >> 4);
  desc |= set_bits(n_dim, 17, 6);
  desc |= set_bits(0, 23, 1);
  desc |= set_bits(m_dim, 24, 5);
  desc |= set_bits(0, 29, 1);
  uint32_t max_shift = 0u;
  desc |= set_bits(max_shift, 30, 2);
  return desc;
}

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_TCGEN5_META_H_
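GetTCGEN5MMAMeta scans candidate atom sizes from largest to smallest and returns the first one that tiles the problem; GetTCGEN5InstrDesc then packs the atom geometry and dtype formats into the 32-bit instruction descriptor. A hedged walk-through, with tile sizes made up for illustration (the values below are computed from the code above, not taken from the commit):

    // Illustrative only: exercises the two helpers above.
    #include "tcgen5_meta.h" // lives under src/op/ in this repo
    using tvm::runtime::DataType;

    void demo() {
      // bf16 inputs, fp32 accumulator, M=256, N=128, K=64:
      // M % 128 == 0, so atom_n scans 256, 240, ..., and 128 is the first
      // divisor of N, giving meta = {128, 128, 16}.
      auto [ok, meta] = tvm::tl::GetTCGEN5MMAMeta(
          256, 128, 64, DataType::BFloat(16), DataType::Float(32));
      // The descriptor packs n_dim = atom_n >> 3 = 16 at bit 17 and
      // m_dim = atom_m >> 4 = 8 at bit 24 (see set_bits above).
      uint32_t desc = tvm::tl::GetTCGEN5InstrDesc(
          meta.atom_m, meta.atom_n, meta.atom_k, DataType::BFloat(16),
          DataType::Float(32), /*a_is_k_major=*/true, /*b_is_k_major=*/true,
          /*scale_in_a=*/1, /*scale_in_b=*/1);
      (void)ok;
      (void)desc;
    }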
src/runtime/runtime.cc
...
@@ -89,7 +89,7 @@ struct TensorMapArgs {
 };

 // set device api
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def_packed(
       "tvm_tensormap_create_tiled", [](PackedArgs args, Any *ret) {
...
@@ -104,7 +104,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
         }
         *ret = static_cast<int>(result);
       });
-});
+}

 struct TensorMapIm2ColArgs {
   CUtensorMap *map;
...
@@ -180,7 +180,7 @@ struct TensorMapIm2ColArgs {
   }
 };

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def_packed(
       "tvm_tensormap_create_im2col", [](PackedArgs args, Any *ret) {
...
@@ -197,7 +197,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
         }
         *ret = static_cast<int>(result);
       });
-});
+}
 #endif // (CUDA_MAJOR_VERSION >= 12)
...
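The TVM_FFI_STATIC_INIT_BLOCK change here is purely syntactic; the registration body moves from a macro argument to an ordinary braced block. A hedged sketch of the shape change (the real macro lives in tvm/ffi; this is illustrative, not the actual definition):

    // old: body passed as a macro argument
    //   TVM_FFI_STATIC_INIT_BLOCK({ ...registrations... });
    // new: macro expands to a function header; the body is plain code
    //   TVM_FFI_STATIC_INIT_BLOCK() { ...registrations... }
    // Keeping the body out of the macro argument avoids preprocessor trouble
    // with commas and braces inside the registration code.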
src/support/ffi_aliases.h (new file, mode 100644)
#pragma once

#include <tvm/ffi/cast.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/string.h>

namespace tvm {
using ffi::Array;
using ffi::Function;
using ffi::Map;
using ffi::Optional;
using ffi::String;
} // namespace tvm
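This header exists so the rest of the codebase can keep its short spellings after the FFI migration. A minimal sketch of the effect (the call sites below are hypothetical):

    #include "../support/ffi_aliases.h"
    // With the using-declarations in namespace tvm, the familiar names
    // resolve to the ffi containers:
    tvm::Array<tvm::PrimExpr> args;  // == tvm::ffi::Array<tvm::PrimExpr>
    tvm::String symbol = "main";     // == tvm::ffi::String
    tvm::Optional<tvm::String> opt;  // == tvm::ffi::Optional<tvm::ffi::String>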
src/target/codegen_cpp.cc
...
@@ -29,6 +29,7 @@
 #include <unordered_set>
 #include <utility>

+#include "../support/ffi_aliases.h"
 #include "support/str_escape.h"
 #include "target/build_common.h"
 #include "target/source/codegen_params.h"
...
@@ -54,8 +55,7 @@ void CodeGenTileLangCPP::Init(bool output_ssa, bool emit_asserts,
 }

 void CodeGenTileLangCPP::InitGlobalContext() {
-  decl_stream << "void* " << tvm::runtime::symbol::tvm_ffi_library_ctx
-              << " = NULL;\n";
+  decl_stream << "void* " << ffi::symbol::tvm_ffi_library_ctx << " = NULL;\n";
 }

 void CodeGenTileLangCPP::DefineModuleName() {
...
@@ -256,8 +256,8 @@ void CodeGenTileLangCPP::AddFunction(const PrimFunc &f) {
   // reserve keywords
   ReserveKeywordsAsUnique();

-  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  ICHECK(global_symbol.defined())
+  auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
+  ICHECK(global_symbol)
       << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
   bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
...
src/target/codegen_cpp.h
...
@@ -73,10 +73,10 @@ public:
   void VisitStmt_(const AssertStmtNode *op) final; // NOLINT(*)
   void VisitStmt_(const AllocateNode *op) final;   // NOLINT(*)

-  void GenerateForwardFunctionDeclarations(String global_symbol,
-                                           const Array<Type> &arg_types,
+  void GenerateForwardFunctionDeclarations(ffi::String global_symbol,
+                                           const ffi::Array<Type> &arg_types,
                                            const Type &ret_type) override;
-  Array<String> GetFunctionNames() { return function_names_; }
+  ffi::Array<ffi::String> GetFunctionNames() { return function_names_; }

 private:
   /* \brief Internal structure to store information about function calls */
...
@@ -92,7 +92,7 @@ private:
   /* \brief mapping global packed func to the unique name */
   std::unordered_map<std::string, std::string> declared_globals_;
   /* \brief names of the functions declared in this module */
-  Array<String> function_names_;
+  ffi::Array<ffi::String> function_names_;
   /*! \brief whether to emit asserts in the resulting C code */
   bool emit_asserts_;
   /*! \brief whether to emit forward function declarations in the resulting C
...
src/target/codegen_cuda.cc
...
@@ -20,6 +20,7 @@
 namespace tvm {
 namespace codegen {

 using namespace tvm::tl::codegen;
+using namespace ffi;

 struct CUDAMath {
   std::string operator()(DataType t, std::string name) const {
...
@@ -259,6 +260,21 @@ std::string CodeGenTileLangCUDA::Finish() {
   if (need_mma_h_) {
     decl_stream << "#include <mma.h>\n";
   }
+  if (need_mma_instruction_h_) {
+    decl_stream << "#include <tl_templates/cuda/instruction/mma.h>\n";
+  }
+  if (need_wgmma_instruction_h_) {
+    decl_stream << "#include <tl_templates/cuda/instruction/wgmma.h>\n";
+  }
+  if (need_tcgen05mma_instruction_h_) {
+    decl_stream << "#include <tl_templates/cuda/instruction/tcgen05mma.h>\n";
+  }
+  if (need_mma_sm70_instruction_h_) {
+    decl_stream << "#include <tl_templates/cuda/instruction/mma_sm70.h>\n";
+  }
+  if (need_tcgen05_common_h_) {
+    decl_stream << "#include <tl_templates/cuda/tcgen_05.h>\n";
+  }
   if (enable_fp8_) {
     decl_stream << "#include <tl_templates/cuda/cuda_fp8.h>\n";
   }
...
@@ -919,6 +935,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
              << "__half22float2(*((half2*)(&(" << src << "))+1));\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+      // half8 -> float8
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[0] = "
+             << "__half22float2(*(half2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[1] = "
+             << "__half22float2(*((half2*)(&(" << src << "))+1));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[2] = "
+             << "__half22float2(*((half2*)(&(" << src << "))+2));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[3] = "
+             << "__half22float2(*((half2*)(&(" << src << "))+3));\n";
+      os << sret;
+      return;
+    }
   } else if (from_ty.is_float() && target_ty.is_float16()) {
     // Use __float22half2_rn for vectorized conversion (float2 -> half2)
...
@@ -939,6 +971,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
              << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+      // float8 -> half8
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[0] = "
+             << "__float22half2_rn(*(float2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[1] = "
+             << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[2] = "
+             << "__float22half2_rn(*((float2*)(&(" << src << "))+2));\n";
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[3] = "
+             << "__float22half2_rn(*((float2*)(&(" << src << "))+3));\n";
+      os << sret;
+      return;
+    }
   }
...
@@ -965,6 +1013,26 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
              << src << "))+1));\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+      // bfloat162x4 -> float8
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[0] = "
+             << "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
+             << src << ")));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[1] = "
+             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
+             << src << "))+1));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[2] = "
+             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
+             << src << "))+2));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[3] = "
+             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
+             << src << "))+3));\n";
+      os << sret;
+      return;
+    }
   } else if (from_ty.is_float() && target_ty.is_bfloat16()) {
     // Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162)
...
@@ -985,6 +1053,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
              << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+      // float8 -> bfloat162x4
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[0] = "
+             << "__float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[1] = "
+             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[2] = "
+             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+2));\n";
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[3] = "
+             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+3));\n";
+      os << sret;
+      return;
+    }
   }
...
@@ -1017,6 +1101,36 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
              << "))+1), __NV_SATFINITE, "
              << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
              << ");\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+      // float8 -> fp8x8
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = "
+             << "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
+             << ")), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = "
+             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
+             << "))+1), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[2] = "
+             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
+             << "))+2), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[3] = "
+             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
+             << "))+3), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      os << sret;
+      return;
+    }
   }
...
@@ -1034,6 +1148,52 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
   os << sret;
 }

+void CodeGenTileLangCUDA::VisitExpr_(const MinNode *op, std::ostream &os) {
+  // TODO(wt): Consider vectorized reduction and impl for other dtypes
+  DataType t = op->dtype;
+  // Standard min/max functions don't support bfloat16 or float16
+  if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) {
+    os << "cutlass::fast_min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b)
+       << ")";
+    return;
+  }
+  // For float32 and float64 scalar, use standard min functions
+  if (t.is_float() && t.is_scalar()) {
+    if (t.bits() == 32 || t.bits() == 64) {
+      os << "min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
+      return;
+    }
+  }
+  // For all other scalar types (int, uint), use default implementation
+  CodeGenC::VisitExpr_(op, os);
+}
+
+void CodeGenTileLangCUDA::VisitExpr_(const MaxNode *op, std::ostream &os) {
+  // TODO(wt): Consider vectorized reduction and impl for other dtypes
+  DataType t = op->dtype;
+  // Standard min/max functions don't support bfloat16 or float16
+  if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) {
+    os << "cutlass::fast_max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b)
+       << ")";
+    return;
+  }
+  // For float32 and float64 scalar, use standard max functions
+  if (t.is_float() && t.is_scalar()) {
+    if (t.bits() == 32 || t.bits() == 64) {
+      os << "max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
+      return;
+    }
+  }
+  // For all other scalar types (int, uint), use default implementation
+  CodeGenC::VisitExpr_(op, os);
+}
 void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol,
                                           const Array<PrimExpr> &args,
                                           bool skip_first_arg,
...
@@ -1132,7 +1292,7 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
   if (scope.empty()) {
     scope = GetPtrStorageScope(buffer->data);
   }
-  if (scope == "local.var" || scope == "local.descriptor") {
+  if (scope == "local.var" || scope.find("local.descriptor") == 0) {
     os << vid;
     return os.str();
   }
...
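The MinNode/MaxNode visitors added above select the emitted intrinsic by dtype. A hedged summary of the output for a scalar `min(a, b)` (variable names are illustrative; the fallback text depends on CodeGenC's default printer):

    // fp16 / bf16 scalar:  cutlass::fast_min(a, b)   (cutlass::fast_max for MaxNode)
    // fp32 / fp64 scalar:  min(a, b)                 (max(a, b))
    // other scalar types:  whatever CodeGenC::VisitExpr_ prints by default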
@@ -1452,6 +1612,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     int num_mma = Downcast<IntImm>(op->args[0])->value;
     this->stream << "tl::warpgroup_wait<" << std::to_string(num_mma)
                  << ">();\n";
+  } else if (op->op.same_as(tl::warpgroup_fence_operand())) {
+    ICHECK_EQ(op->args.size(), 4U);
+    std::string dtype = Downcast<StringImm>(op->args[0])->value;
+    std::string data_ptr = this->PrintExpr(op->args[1]);
+    std::string offset = this->PrintExpr(op->args[2]);
+    std::string num_regs = this->PrintExpr(op->args[3]);
+    auto dtype_enum = tl::codegen::ptx::DTypeFromString(dtype);
+    std::string cast_type = "uint32_t";
+    if (dtype_enum == tl::codegen::ptx::DataType::kFloat32 ||
+        dtype_enum == tl::codegen::ptx::DataType::kTensorFloat32) {
+      cast_type = "float";
+    }
+    this->PrintIndent();
+    this->stream << "tl::warpgroup_fence_operand(reinterpret_cast<"
+                 << cast_type << "*>(" << data_ptr << " + " << offset << "), "
+                 << num_regs << ");\n";
   } else if (op->op.same_as(tl::set_max_nreg())) {
     this->PrintIndent();
     int nreg = Downcast<IntImm>(op->args[0])->value;
...
@@ -1563,14 +1739,124 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     std::string b_bias = this->PrintExpr(op->args[9]);
     std::string c_ref = this->PrintExpr(op->args[10]);
     std::string c_bias = this->PrintExpr(op->args[11]);
     bool saturate = Downcast<Bool>(op->args[12])->value;
     std::string bit_op =
         op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
-    std::string asm_code =
-        PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype,
-                         a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, "", "",
-                         "", bit_op, false, saturate);
+    auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype);
+    auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype);
+    auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
+    auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
+    need_mma_instruction_h_ = true;
     this->PrintIndent();
-    this->stream << asm_code;
+    std::string mma_call =
+        "tl::mma_sync<(AType), (BType), (CType), (M), (N), (K), (TransA), "
+        "(TransB)>(reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), "
+        "reinterpret_cast<const (ARegType)*>((A_ptr) + (A_offset)), "
+        "reinterpret_cast<const (BRegType)*>((B_ptr) + (B_offset)));\n";
+    tl::codegen::Replacer replacer;
+    // TODO(lei): Type Workaround for TF32, should be removed when
+    // we introduced tfloat32_t in the frontend.
+    std::string AType = tl::codegen::ptx::DTypeEnumToString(dtype_a_enum);
+    if (AType == "tl::DataType::kFloat32") {
+      AType = "tl::DataType::kTensorFloat32";
+    }
+    std::string BType = tl::codegen::ptx::DTypeEnumToString(dtype_b_enum);
+    if (BType == "tl::DataType::kFloat32") {
+      BType = "tl::DataType::kTensorFloat32";
+    }
+    std::string ARegType = tl::codegen::GetMMARegisterType(dtype_a_enum);
+    if (ARegType == "float") {
+      ARegType = "uint32_t";
+    }
+    std::string BRegType = tl::codegen::GetMMARegisterType(dtype_b_enum);
+    if (BRegType == "float") {
+      BRegType = "uint32_t";
+    }
+    replacer.register_rule("(AType)", AType);
+    replacer.register_rule("(BType)", BType);
+    replacer.register_rule("(CType)",
+                           tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
+    replacer.register_rule("(M)", std::to_string(m));
+    replacer.register_rule("(N)", std::to_string(n));
+    replacer.register_rule("(K)", std::to_string(k));
+    replacer.register_rule("(TransA)", A_layout == "row" ? "false" : "true");
+    replacer.register_rule("(TransB)", B_layout == "row" ? "false" : "true");
+    replacer.register_rule("(ARegType)", ARegType);
+    replacer.register_rule("(BRegType)", BRegType);
+    replacer.register_rule("(CRegType)",
+                           tl::codegen::GetMMARegisterType(dtype_c_enum));
+    replacer.register_rule("(A_ptr)", a_ref);
+    replacer.register_rule("(A_offset)", a_bias);
+    replacer.register_rule("(B_ptr)", b_ref);
+    replacer.register_rule("(B_offset)", b_bias);
+    replacer.register_rule("(C_ptr)", c_ref);
+    replacer.register_rule("(C_offset)", c_bias);
+    this->stream << replacer.rewrite(mma_call);
+  } else if (op->op.same_as(tl::ptx_mma_sm70())) {
+    // arg 0: shape: mXnXkX
+    // arg 1: A layout: row/col
+    // arg 2: B layout: row/col
+    // arg 3: A precision: fp16
+    // arg 4: B precision: fp16
+    // arg 5: C precision: fp16, fp32
+    // arg 6: A multiplicand
+    // arg 7: A multiplicand index
+    // arg 8: B multiplicand
+    // arg 9: B multiplicand index
+    // arg 10: C accumulator
+    // arg 11: C accumulator index
+    // arg 12: saturate
+    ICHECK_EQ(op->args.size(), 12U);
+    std::string shape = Downcast<StringImm>(op->args[0])->value;
+    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
+    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
+    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
+    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
+    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
+    std::string a_ref = this->PrintExpr(op->args[6]);
+    std::string a_bias = this->PrintExpr(op->args[7]);
+    std::string b_ref = this->PrintExpr(op->args[8]);
+    std::string b_bias = this->PrintExpr(op->args[9]);
+    std::string c_ref = this->PrintExpr(op->args[10]);
+    std::string c_bias = this->PrintExpr(op->args[11]);
+    auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype);
+    auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype);
+    auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
+    auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
+    need_mma_sm70_instruction_h_ = true;
+    this->PrintIndent();
+    std::string mma_call =
+        "tl::mma_sync_sm70<(AType), (BType), (CType), (M), (N), (K), "
+        "(TransA), (TransB)>(reinterpret_cast<(CRegType)*>((C_ptr) + "
+        "(C_offset)), reinterpret_cast<const (ARegType)*>((A_ptr) + "
+        "(A_offset)), reinterpret_cast<const (BRegType)*>((B_ptr) + "
+        "(B_offset)));\n";
+    tl::codegen::Replacer replacer;
+    replacer.register_rule("(AType)",
+                           tl::codegen::ptx::DTypeEnumToString(dtype_a_enum));
+    replacer.register_rule("(BType)",
+                           tl::codegen::ptx::DTypeEnumToString(dtype_b_enum));
+    replacer.register_rule("(CType)",
+                           tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
+    replacer.register_rule("(M)", std::to_string(m));
+    replacer.register_rule("(N)", std::to_string(n));
+    replacer.register_rule("(K)", std::to_string(k));
+    replacer.register_rule("(TransA)", A_layout == "row" ? "false" : "true");
+    replacer.register_rule("(TransB)", B_layout == "row" ? "false" : "true");
+    replacer.register_rule("(ARegType)",
+                           tl::codegen::GetMMARegisterType(dtype_a_enum));
+    replacer.register_rule("(BRegType)",
+                           tl::codegen::GetMMARegisterType(dtype_b_enum));
+    replacer.register_rule("(CRegType)",
+                           tl::codegen::GetMMARegisterType(dtype_c_enum));
+    replacer.register_rule("(A_ptr)", a_ref);
+    replacer.register_rule("(A_offset)", a_bias);
+    replacer.register_rule("(B_ptr)", b_ref);
+    replacer.register_rule("(B_offset)", b_bias);
+    replacer.register_rule("(C_ptr)", c_ref);
+    replacer.register_rule("(C_offset)", c_bias);
+    this->stream << replacer.rewrite(mma_call);
   } else if (op->op.same_as(builtin::ptx_mma_sp())) {
     // arg 0: shape: mXnXkX
     // arg 1: A layout: row/col
...
@@ -1636,27 +1922,32 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     std::string B_offset = this->PrintExpr(op->args[9]);
     std::string c_ref = this->PrintExpr(op->args[10]);
     std::string c_offset = this->PrintExpr(op->args[11]);
-    bool scale_out = Downcast<Bool>(op->args[12])->value;
+    std::string scale_out = this->PrintExpr(op->args[12]);
     bool scale_in_a = Downcast<Bool>(op->args[13])->value;
     bool scale_in_b = Downcast<Bool>(op->args[14])->value;
-    const bool a_is_shared = true;
     this->PrintIndent();
-    std::string asm_code = PrintWGMMAAssembly(
-        shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc,
-        A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a,
-        scale_in_b, a_is_shared, "", "", "", false);
+    auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
+    need_wgmma_instruction_h_ = true;
     std::string wgmma_asm_code =
         "tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), "
         "(tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), "
         "uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n";
     // replace patterns
     tl::codegen::Replacer replacer;
-    replacer.register_rule("(AType)",
-                           tl::codegen::ptx::DTypeEnumToString(A_dtype));
-    replacer.register_rule("(BType)",
-                           tl::codegen::ptx::DTypeEnumToString(B_dtype));
+    std::string AType = tl::codegen::ptx::DTypeEnumToString(A_dtype);
+    if (AType == "tl::DataType::kFloat32") {
+      AType = "tl::DataType::kTensorFloat32";
+    }
+    std::string BType = tl::codegen::ptx::DTypeEnumToString(B_dtype);
+    if (BType == "tl::DataType::kFloat32") {
+      BType = "tl::DataType::kTensorFloat32";
+    }
+    replacer.register_rule("(AType)", AType);
+    replacer.register_rule("(BType)", BType);
     replacer.register_rule("(CType)",
                            tl::codegen::ptx::DTypeEnumToString(C_dtype));
     replacer.register_rule("(M)", std::to_string(m));
...
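To make the Replacer mechanism concrete: the (X) placeholders in the template strings above are plain-text rewrite rules. For a hypothetical ptx_mma call with shape m16n8k16, fp16 x fp16 -> fp32, row-major A and col-major B, the rewritten mma_call would come out roughly as below (buffer names A_local/B_local/C_local and zero offsets are made up; the tl::DataType spellings are assumed from the kFloat32 comparison in the code above):

    tl::mma_sync<tl::DataType::kFloat16, tl::DataType::kFloat16,
                 tl::DataType::kFloat32, 16, 8, 16, false, true>(
        reinterpret_cast<float*>(C_local + 0),
        reinterpret_cast<const uint32_t*>(A_local + 0),
        reinterpret_cast<const uint32_t*>(B_local + 0));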
@@ -1671,45 +1962,184 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     replacer.register_rule("(desc_b)", b_desc);
     replacer.register_rule("(B_offset)", B_offset);
     replacer.register_rule("(C)", c_ref + " + " + c_offset);
-    replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
+    replacer.register_rule("(scale_out)", scale_out);
     wgmma_asm_code = replacer.rewrite(wgmma_asm_code);
     this->stream << wgmma_asm_code;
   } else if (op->op.same_as(tl::ptx_wgmma_rs())) {
-    // arg 0: dtype
-    // arg 1: shape
-    // arg 2: A_layout
-    // arg 3: B_layout
-    // arg 4: A_dtype
-    // arg 5: B_dtype
-    // arg 6: C_dtype
-    // arg 7: multiplicand_a
-    // arg 8: multiplicand_b
-    // arg 9: accumulator
-    // arg 10: saturate
-    ICHECK_EQ(op->args.size(), 15U) << "ptx_wgmma_rs args is " << op->args;
+    // arg 0: shape
+    // arg 1: B_layout
+    // arg 2: A_dtype
+    // arg 3: B_dtype
+    // arg 4: C_dtype
+    // arg 5: multiplicand_a
+    // arg 6: multiplicand_a offset
+    // arg 7: multiplicand_b descriptor
+    // arg 8: multiplicand_b offset
+    // arg 9: accumulator
+    // arg 10: accumulator offset
+    // arg 11: scale_out
+    // arg 12: scale_in_a
+    // arg 13: scale_in_b
+    ICHECK_EQ(op->args.size(), 14U) << "ptx_wgmma_rs args is " << op->args;
     std::string shape = Downcast<StringImm>(op->args[0])->value;
-    bool A_layout = Downcast<Bool>(op->args[1])->value;
-    bool B_layout = Downcast<Bool>(op->args[2])->value;
-    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
-    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
-    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
-    std::string a_ref = this->PrintExpr(op->args[6]);
-    std::string A_offset = this->PrintExpr(op->args[7]);
-    std::string b_desc = this->PrintExpr(op->args[8]);
-    std::string B_offset = this->PrintExpr(op->args[9]);
-    std::string c_ref = this->PrintExpr(op->args[10]);
-    std::string c_offset = this->PrintExpr(op->args[11]);
-    bool scale_out = Downcast<Bool>(op->args[12])->value;
-    bool scale_in_a = Downcast<Bool>(op->args[13])->value;
-    bool scale_in_b = Downcast<Bool>(op->args[14])->value;
+    bool b_is_k_major = Downcast<Bool>(op->args[1])->value;
+    std::string A_dtype = Downcast<StringImm>(op->args[2])->value;
+    std::string B_dtype = Downcast<StringImm>(op->args[3])->value;
+    std::string C_dtype = Downcast<StringImm>(op->args[4])->value;
+    std::string a_ref = this->PrintExpr(op->args[5]);
+    std::string A_offset = this->PrintExpr(op->args[6]);
+    std::string b_desc = this->PrintExpr(op->args[7]);
+    std::string B_offset = this->PrintExpr(op->args[8]);
+    std::string c_ref = this->PrintExpr(op->args[9]);
+    std::string c_offset = this->PrintExpr(op->args[10]);
+    std::string scale_out = this->PrintExpr(op->args[11]);
+    bool scale_in_a = Downcast<Bool>(op->args[12])->value;
+    bool scale_in_b = Downcast<Bool>(op->args[13])->value;
+    auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype);
+    auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype);
+    auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
+    auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
-    const bool a_is_shared = false;
+    need_wgmma_instruction_h_ = true;
     this->PrintIndent();
-    std::string asm_code = PrintWGMMAAssembly(
-        shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, A_offset,
-        b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a, scale_in_b,
-        a_is_shared, "", "", "", false);
-    this->stream << asm_code;
+    std::string wgmma_call =
+        "tl::wgmma_rs<(AType), (BType), (CType), (M), (N), (K), (tnspA), "
+        "(tnspB), (scaleA), (scaleB)>(reinterpret_cast<const "
+        "uint32_t*>((A_ptr) + (A_offset)), "
+        "uint64_t((desc_b) + (B_offset)), "
+        "reinterpret_cast<uint32_t*>((C_ptr) + (C_offset)), "
+        "(scale_out));\n";
+    tl::codegen::Replacer replacer;
+    std::string AType = tl::codegen::ptx::DTypeEnumToString(dtype_a_enum);
+    if (AType == "tl::DataType::kFloat32") {
+      AType = "tl::DataType::kTensorFloat32";
+    }
+    std::string BType = tl::codegen::ptx::DTypeEnumToString(dtype_b_enum);
+    if (BType == "tl::DataType::kFloat32") {
+      BType = "tl::DataType::kTensorFloat32";
+    }
+    replacer.register_rule("(AType)", AType);
+    replacer.register_rule("(BType)", BType);
+    replacer.register_rule("(CType)",
+                           tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
+    replacer.register_rule("(M)", std::to_string(m));
+    replacer.register_rule("(N)", std::to_string(n));
+    replacer.register_rule("(K)", std::to_string(k));
+    replacer.register_rule("(tnspA)", "false");
+    replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true");
+    replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1");
+    replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1");
+    replacer.register_rule("(A_ptr)", a_ref);
+    replacer.register_rule("(A_offset)", A_offset);
+    replacer.register_rule("(desc_b)", b_desc);
+    replacer.register_rule("(B_offset)", B_offset);
+    replacer.register_rule("(C_ptr)", c_ref);
+    replacer.register_rule("(C_offset)", c_offset);
+    replacer.register_rule("(scale_out)", scale_out);
+    wgmma_call = replacer.rewrite(wgmma_call);
+    this->stream << wgmma_call;
+  } else if (op->op.same_as(tl::ptx_tcgen05_mma_ss())) {
+    ICHECK_EQ(op->args.size(), 14U)
+        << "ptx_tcgen05_mma_ss args is " << op->args;
+    std::string C_dtype = Downcast<StringImm>(op->args[0])->value;
+    std::string a_desc = this->PrintExpr(op->args[1]);
+    std::string A_offset = this->PrintExpr(op->args[2]);
+    std::string b_desc = this->PrintExpr(op->args[3]);
+    std::string B_offset = this->PrintExpr(op->args[4]);
+    std::string c_ref = this->PrintExpr(op->args[5]);
+    std::string c_offset = this->PrintExpr(op->args[6]);
+    PrimExpr desc_expr = op->args[7];
+    std::string scale_out = this->PrintExpr(op->args[8]);
+    std::string mask0 = this->PrintExpr(op->args[9]);
+    std::string mask1 = this->PrintExpr(op->args[10]);
+    std::string mask2 = this->PrintExpr(op->args[11]);
+    std::string mask3 = this->PrintExpr(op->args[12]);
+    bool enable_ws = Downcast<Bool>(op->args[13])->value;
+    auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
+    need_tcgen05mma_instruction_h_ = true;
+    this->PrintIndent();
+    std::string tcgen05_call =
+        "tl::(tcgen05_name)<(CType)>(uint64_t((desc_a) + (A_offset)), "
+        "uint64_t((desc_b) + (B_offset)), (*reinterpret_cast<uint32_t*>((C))) "
+        "+ (C_offset), "
+        "(scale_out), static_cast<uint32_t>((desc_val)), (mask0), (mask1), "
+        "(mask2), (mask3));\n";
+    tl::codegen::Replacer replacer;
+    replacer.register_rule("(CType)",
+                           tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
+    replacer.register_rule("(desc_a)", a_desc);
+    replacer.register_rule("(A_offset)", A_offset);
+    replacer.register_rule("(desc_b)", b_desc);
+    replacer.register_rule("(B_offset)", B_offset);
+    replacer.register_rule("(C)", c_ref);
+    replacer.register_rule("(C_offset)", c_offset);
+    replacer.register_rule("(tcgen05_name)",
+                           enable_ws ? "tcgen05mma_ws_ss" : "tcgen05mma_ss");
+    replacer.register_rule("(scale_out)", scale_out);
+    replacer.register_rule("(desc_val)", this->PrintExpr(desc_expr));
+    replacer.register_rule("(mask0)", mask0);
+    replacer.register_rule("(mask1)", mask1);
+    replacer.register_rule("(mask2)", mask2);
+    replacer.register_rule("(mask3)", mask3);
+    tcgen05_call = replacer.rewrite(tcgen05_call);
+    this->stream << tcgen05_call;
+  } else if (op->op.same_as(tl::ptx_tcgen05_mma_ts())) {
+    // TS: A from TMEM, B from SMEM (desc)
+    ICHECK_EQ(op->args.size(), 13U)
+        << "ptx_tcgen05_mma_ts args is " << op->args;
+    std::string kind_dtype = Downcast<StringImm>(op->args[0])->value;
+    std::string a_ref = this->PrintExpr(op->args[1]);
+    std::string A_offset = this->PrintExpr(op->args[2]);
+    std::string b_desc = this->PrintExpr(op->args[3]);
+    std::string B_offset = this->PrintExpr(op->args[4]);
+    std::string c_ref = this->PrintExpr(op->args[5]);
+    std::string c_offset = this->PrintExpr(op->args[6]);
+    PrimExpr desc_expr = op->args[7];
+    std::string scale_out = this->PrintExpr(op->args[8]);
+    std::string mask0 = this->PrintExpr(op->args[9]);
+    std::string mask1 = this->PrintExpr(op->args[10]);
+    std::string mask2 = this->PrintExpr(op->args[11]);
+    std::string mask3 = this->PrintExpr(op->args[12]);
+    auto dtype_enum = tl::codegen::ptx::DTypeFromString(kind_dtype);
+    need_tcgen05mma_instruction_h_ = true;
+    this->PrintIndent();
+    std::string tcgen05_call =
+        "tl::tcgen05mma_ts<(CType)>( (*reinterpret_cast<uint32_t*>((A))) + "
+        "(A_offset), "
+        "uint64_t((desc_b) + (B_offset)), (*reinterpret_cast<uint32_t*>((C))) "
+        "+ (C_offset), "
+        "(scale_out), static_cast<uint32_t>((desc_val)), (mask0), (mask1), "
+        "(mask2), (mask3));\n";
+    tl::codegen::Replacer replacer;
+    replacer.register_rule("(CType)",
+                           tl::codegen::ptx::DTypeEnumToString(dtype_enum));
+    replacer.register_rule("(A)", a_ref);
+    replacer.register_rule("(A_offset)", A_offset);
+    replacer.register_rule("(desc_b)", b_desc);
+    replacer.register_rule("(B_offset)", B_offset);
+    replacer.register_rule("(C)", c_ref);
+    replacer.register_rule("(C_offset)", c_offset);
+    replacer.register_rule("(scale_out)", scale_out);
+    replacer.register_rule("(desc_val)", this->PrintExpr(desc_expr));
+    replacer.register_rule("(mask0)", mask0);
+    replacer.register_rule("(mask1)", mask1);
+    replacer.register_rule("(mask2)", mask2);
+    replacer.register_rule("(mask3)", mask3);
+    tcgen05_call = replacer.rewrite(tcgen05_call);
+    this->stream << tcgen05_call;
+  } else if (op->op.same_as(tl::tcgen05_mma_arrive())) {
+    ICHECK_EQ(op->args.size(), 1U) << "tcgen05_mma_arrive expects 1 argument";
+    need_tcgen05_common_h_ = true;
+    this->PrintIndent();
+    this->stream << "tl::tcgen05_mma_arrive("
+                 << this->PrintExpr(op->args[0]) << ");\n";
   } else if (op->op.same_as(builtin::ptx_ldmatrix())) {
     // arg 0: whether the matrix is loaded in column major format or not.
     // arg 1: number of matrices to load.
...
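For the tcgen05 path, the same Replacer substitution applies; a hypothetical rewritten output of the tcgen05_call template above (descriptor and accumulator names are made up, and the CType spelling follows the same tl::DataType convention assumed earlier; enable_ws would select tcgen05mma_ws_ss instead):

    tl::tcgen05mma_ss<tl::DataType::kFloat32>(
        uint64_t(desc_a + 0), uint64_t(desc_b + 0),
        (*reinterpret_cast<uint32_t*>(C_tmem)) + 0,
        scale_out, static_cast<uint32_t>(instr_desc),
        mask0, mask1, mask2, mask3);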
@@ -2021,8 +2451,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
           "A_ptr, B_ptr, C_ptr>, but got "
        << op->args.size();
     auto op_instance = Downcast<StringImm>(op->args[0]);
-    this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
-                          op->args, true, os);
+    this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
+                          op_instance->value, op->args, true, os);
   } else if (op->op.same_as(tl::tl_gemm_sp())) {
     ICHECK(op->args.size() == 5)
         << "tl_gemm_sp expects 5 arguments <op_instance, A_ptr, B_ptr, C_ptr, "
...
@@ -2030,8 +2460,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
        << op->args.size();
     auto op_instance = Downcast<StringImm>(op->args[0]);
     enable_sparse_gemm_ = true;
-    this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
-                          op->args, true, os);
+    this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
+                          op_instance->value, op->args, true, os);
   } else if (op->op.same_as(tl::get_lane_idx())) {
     ICHECK_LE(op->args.size(), 1)
         << "tl.get_lane_idx expects at most one argument <warp_size>.";
...
@@ -2069,19 +2499,35 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     os << ")";
   } else if (op->op.same_as(tl::tl_shuffle_elect())) {
     os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()";
-  } else if (op->op.same_as(tl::initialize_descriptor())) {
+  } else if (op->op.same_as(tl::initialize_wgmma_descriptor())) {
     ICHECK(op->args.size() == 5)
-        << "tl_initialize_descriptor expects 5 arguments but got "
+        << "tl_initialize_wgmma_descriptor expects 5 arguments but got "
        << op->args.size();
     auto descriptor = op->args[0];
     auto start_address = op->args[1];
     auto layout_type = op->args[2];
     auto leading_byte_offset = op->args[3];
     auto stride_byte_offset = op->args[4];
-    os << "tl::initialize_descriptor<" << PrintExpr(layout_type) << ", "
+    os << "tl::initialize_wgmma_descriptor<" << PrintExpr(layout_type) << ", "
        << PrintExpr(leading_byte_offset) << ", "
        << PrintExpr(stride_byte_offset) << ">(" << PrintExpr(descriptor)
        << ", " << PrintExpr(start_address) << ")";
+  } else if (op->op.same_as(tl::initialize_tcgen05_descriptor())) {
+    ICHECK(op->args.size() == 7)
+        << "tl_initialize_tcgen05_descriptor expects 7 arguments but got "
+        << op->args.size();
+    auto descriptor = op->args[0];
+    auto start_address = op->args[1];
+    auto leading_byte_offset = op->args[2];
+    auto stride_byte_offset = op->args[3];
+    auto base_offset = op->args[4];
+    auto leading_abs = op->args[5];
+    auto swizzle_mode = op->args[6];
+    os << "tl::initialize_tcgen05_descriptor(" << PrintExpr(descriptor)
+       << ", " << PrintExpr(start_address) << ", "
+       << PrintExpr(leading_byte_offset) << ", "
+       << PrintExpr(stride_byte_offset) << ", " << PrintExpr(base_offset)
+       << ", " << PrintExpr(leading_abs) << ", " << PrintExpr(swizzle_mode)
+       << ")";
   } else if (op->op.same_as(tl::increase_descriptor_offset())) {
     ICHECK(op->args.size() == 2)
         << "tl_increase_descriptor_offset expects 2 arguments but got "
...
@@ -2232,8 +2678,12 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
           << "Accumulator only support half, float and int type for now";
     }
     PrintWmmaScope(scope, op->dtype, buffer, stream);
-  } else if (scope == "local.descriptor") {
+  } else if (scope == "local.descriptor.wgmma") {
     stream << "tl::GmmaDescriptor " << vid << ";\n";
+  } else if (scope == "local.descriptor.tcgen05_smem") {
+    stream << "tl::Tcgen05SMemDescriptor " << vid << ";\n";
+  } else if (scope == "local.descriptor.tcgen05_instr") {
+    stream << "tl::Tcgen05InstrDescriptor " << vid << ";\n";
   } else {
     PrintStorageScope(scope, stream);
     PrintType(op->dtype, stream);
...
@@ -2275,7 +2725,7 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
       init = user_init;
     }
     stream << ' ' << vid << " = " << PrintExpr(init) << ";\n";
-  } else if (scope != "local.descriptor") {
+  } else if (scope.find("local.descriptor") != 0) {
     ICHECK(false) << "Unsupported scope: " << scope;
   }
 }
...
@@ -2297,6 +2747,16 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {
     stream << "  " << vid_global_barrier_expect_ << " = 0;\n";
     PrintIndent();
     stream << "}\n";
   }
   if (call && (call->op.same_as(tvm::tl::device_assert()))) {
     std::string cond = PrintExpr(call->args[0]);
     this->PrintIndent();
     stream << "device_assert(" << cond << ");\n";
+  } else if (call && call->op.same_as(tvm::tl::device_assert_with_msg())) {
+    std::string cond = PrintExpr(call->args[0]);
+    std::string msg_expr = PrintExpr(call->args[1]);
+    this->PrintIndent();
+    stream << "device_assert_with_msg(" << cond << ", " << msg_expr
+           << ");\n";
   } else {
     CodeGenC::VisitStmt_(op);
   }
...
@@ -2304,8 +2764,8 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {
 void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
   int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
-  CHECK_LE(lanes, 4) << "Translate Ramp Node " << GetRef<Ramp>(op) << " with "
-                     << lanes << " lanes is not allowed.";
+  CHECK_LE(lanes, 4) << "Translate Ramp Node " << tvm::ffi::GetRef<Ramp>(op)
+                     << " with " << lanes << " lanes is not allowed.";
   os << "(make_";
   PrintType(op->dtype, os);
   os << "(";
...
@@ -2540,12 +3000,29 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
 inline void PrintConst(const FloatImmNode *op, std::ostream &os,
                        CodeGenTileLangCUDA *p) { // NOLINT(*)
-  // Type code is kBFloat
-  if (op->dtype.is_bfloat16()) {
-    os << "bfloat16_t";
-    os << '(' << std::hexfloat << op->value << 'f';
-    os << "/*" << std::scientific << op->value << "*/";
-    os << ')';
+  // Type code is kBFloat/kFloat16
+  // which is indeed CUTLASS supported types currently
+  if (op->dtype.is_bfloat16() || op->dtype.is_float16()) {
+    std::ostringstream temp;
+    if (std::isinf(op->value)) {
+      if (op->value < 0) {
+        temp << "-";
+      }
+      temp << "std::numeric_limits<";
+      p->PrintType(op->dtype, temp);
+      temp << ">::infinity()";
+    } else if (std::isnan(op->value)) {
+      temp << "std::numeric_limits<";
+      p->PrintType(op->dtype, temp);
+      temp << ">::quiet_NaN()";
+    } else {
+      p->PrintType(op->dtype, temp);
+      temp << '(' << std::hexfloat << op->value << 'f';
+      temp << "/*" << std::scientific << op->value << "*/";
+      temp << ')';
+    }
+    p->MarkConst(temp.str());
+    os << temp.str();
     return;
   }
   // Type code is kFloat8_e5m2 or kE4M4Float
...
@@ -2556,7 +3033,7 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
     os << ')';
     return;
   }
-  // Type code is kFloat
+  // Type code is kFloat64/kFloat32 (kFloat16 is handled above)
   switch (op->dtype.bits()) {
     case 64:
     case 32: {
...
@@ -2580,13 +3057,6 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
       os << temp.str();
       break;
     }
-    case 16: {
-      os << "half_t" << '(';
-      FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
-      PrintConst(const_f32.get(), os, p);
-      os << ')';
-      break;
-    }
     default:
       LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
   }
...
@@ -2807,7 +3277,7 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
   ReserveKeywordsAsUnique();

   auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  ICHECK(global_symbol.defined())
+  ICHECK(global_symbol)
       << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
   bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
...
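The PrintConst rework routes fp16 through the same path as bf16 and handles non-finite values explicitly. Roughly, the emitted CUDA literals look like this (illustrative output; the half_t/bfloat16_t spellings assume PrintType prints those names for fp16/bf16, as elsewhere in this codegen):

    // FloatImm(float16, 3.0)   ->  half_t(0x1.8p+1f/*3.000000e+00*/)
    // FloatImm(bfloat16, +inf) ->  std::numeric_limits<bfloat16_t>::infinity()
    // FloatImm(float16, NaN)   ->  std::numeric_limits<half_t>::quiet_NaN()
    // The deleted case 16 branch instead wrapped fp16 constants as
    // half_t(<float32 literal>).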
src/target/codegen_cuda.h
...
@@ -51,6 +51,8 @@ public:
   void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
   void VisitExpr_(const CallNode *op, std::ostream &os) final;
   void VisitExpr_(const CastNode *op, std::ostream &os) final;
+  void VisitExpr_(const MinNode *op, std::ostream &os) final;
+  void VisitExpr_(const MaxNode *op, std::ostream &os) final;
   void VisitStmt_(const EvaluateNode *op) final;
   void VisitStmt_(const AllocateNode *op) final;
   void VisitStmt_(const AttrStmtNode *op) final;
...
@@ -58,14 +60,14 @@ public:
   // Override this as a work around for __grid_constant__ parameter
   void AddFunction(const GlobalVar &gvar, const PrimFunc &f);
-  void PrintFunctionSignature(const String &function_name,
-                              const PrimFunc &func, std::ostream &os);
+  void PrintFunctionSignature(const ffi::String &function_name,
+                              const PrimFunc &func, std::ostream &os);

 protected:
   virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
                                    PrimExpr index) final;
-  void PrintCallExtern(Type ret_type, String global_symbol,
-                       const Array<PrimExpr> &args, bool skip_first_arg,
+  void PrintCallExtern(Type ret_type, ffi::String global_symbol,
+                       const ffi::Array<PrimExpr> &args, bool skip_first_arg,
                        std::ostream &os) final; // NOLINT(*)

 private:
...
@@ -106,6 +108,16 @@ private:
   bool need_math_constants_h_{false};
   // whether need mma.h
   bool need_mma_h_{false};
+  // whether need tl mma instruction header
+  bool need_mma_instruction_h_{false};
+  // whether need tl wgmma instruction header
+  bool need_wgmma_instruction_h_{false};
+  // whether need tl tcgen05mma instruction header
+  bool need_tcgen05mma_instruction_h_{false};
+  // whether need tl mma_sm70 instruction header
+  bool need_mma_sm70_instruction_h_{false};
+  // whether need tcgen_05 common header
+  bool need_tcgen05_common_h_{false};
   // whether need cast_smem_ptr_to_int helper function
   bool need_cast_smem_ptr_to_int_{false};
   // whether need cooperative_groups.h
...
src/target/codegen_hip.cc
...
...
@@ -929,7 +929,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
{
"float32"
,
"float"
},
{
"float64"
,
"double"
},
{
"float16x4"
,
"float16x4"
},
{
"bfloat16x4"
,
"bfloat16x4"
},
{
"bfloat16x4"
,
"bfloat16x4
_vec
"
},
{
"float32x4"
,
"float32x4"
},
{
"float8_e4m3fnuzx4"
,
"fp8_e4_4_t"
},
{
"float8_e4m3fnuzx8"
,
"long"
},
...
...
@@ -1025,8 +1025,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
"A_ptr, B_ptr, C_ptr>, but got "
<<
op
->
args
.
size
();
auto
op_instance
=
Downcast
<
StringImm
>
(
op
->
args
[
0
]);
this
->
PrintCallExtern
(
GetType
(
GetRef
<
PrimExpr
>
(
op
)),
op_instance
->
value
,
op
->
args
,
true
,
os
);
this
->
PrintCallExtern
(
GetType
(
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
)),
op_instance
->
value
,
op
->
args
,
true
,
os
);
}
else
if
(
op
->
op
.
same_as
(
tl
::
tl_gemm_sp
()))
{
LOG
(
FATAL
)
<<
"tl_gemm_sp is not supported on HIP"
;
}
else
if
(
op
->
op
.
same_as
(
tl
::
loop_break
()))
{
...
...
@@ -1375,7 +1375,7 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) {
  ReserveKeywordsAsUnique();
  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  ICHECK(global_symbol.defined())
+  ICHECK(global_symbol.has_value())
      << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
  bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
...
...
src/target/codegen_hip.h View file @ bbbf4207
...
...
@@ -56,8 +56,8 @@ public:
protected:
  virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
                                   PrimExpr index) final;
-  void PrintCallExtern(Type ret_type, String global_symbol,
-                       const Array<PrimExpr> &args, bool skip_first_arg,
+  void PrintCallExtern(Type ret_type, ffi::String global_symbol,
+                       const ffi::Array<PrimExpr> &args, bool skip_first_arg,
                       std::ostream &os) final; // NOLINT(*)

private:
...
...
src/target/codegen_webgpu.cc deleted 100644 → 0 View file @ 8f4628e0
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_webgpu.cc
*/
#include "codegen_webgpu.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>
#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "arith/pattern_match.h"
#include "runtime/meta_data.h"
#include "runtime/thread_storage_scope.h"
#include "target/build_common.h"
namespace tvm {
namespace codegen {

// WebGPU Info
struct WebGPUWorkGroupInfo {
  int workgroup_size[3] = {1, 1, 1};
  // whether we have ref to block index z is used.
  bool has_block_index_z{false};
  // set of handles that have write access
  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> write_access_set;
};
class WebGPUWorkgroupInfoCollector : public StmtExprVisitor {
public:
  static WebGPUWorkGroupInfo Collect(const Stmt &stmt) {
    WebGPUWorkgroupInfoCollector collector;
    collector(stmt);
    return collector.info_;
  }

private:
  void VisitExpr_(const VarNode *op) final {
    StmtExprVisitor::VisitExpr_(op);
    Var buffer_var = GetRef<Var>(op);
    if (buffer_var.dtype().is_handle()) {
      info_.write_access_set.insert(buffer_var);
    }
  }
  void VisitStmt_(const BufferStoreNode *op) final {
    StmtExprVisitor::VisitStmt_(op);
    info_.write_access_set.insert(op->buffer->data);
  }
  void VisitStmt_(const AttrStmtNode *op) final {
    // record workgroup size
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (!iv->thread_tag.empty()) {
        runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
        if (ts.rank == 1) {
          ICHECK_GE(ts.dim_index, 0)
              << "vthread should have been optimized out by here";
          ICHECK_LT(ts.dim_index, 3);
          auto *sizeptr = op->value.as<tir::IntImmNode>();
          ICHECK(sizeptr) << "CodeGenTileLangWebGPU: only allows constant "
                             "thread group size "
                          << " get " << op->value;
          info_.workgroup_size[ts.dim_index] =
              static_cast<uint32_t>(sizeptr->value);
        } else if (ts.rank == 0) {
          if (ts.dim_index == 2) {
            info_.has_block_index_z = true;
          }
        }
      }
    }
    // normal operation
    StmtExprVisitor::VisitStmt_(op);
  }
  WebGPUWorkGroupInfo info_;
};
std::string CodeGenTileLangWebGPU::Finish() {
  // Using f16 requires enable directive
  if (enable_fp16_) {
    header_stream << "enable f16;\n\n";
  }
  // WebGPU WGSL doesn't support #include.
  // We must explicitly include all the templates here.
  return header_stream.str() + decl_stream.str() +
         this->fwd_decl_stream.str() + stream.str();
}

void CodeGenTileLangWebGPU::InitFuncState(const PrimFunc &f) {
  CodeGenC::InitFuncState(f);
  // analyze the data;
  for (Var arg : f->params) {
    if (arg.dtype().is_handle()) {
      alloc_storage_scope_[arg.get()] = "global";
    }
  }
}

CodeGenTileLangWebGPU::CodeGenTileLangWebGPU(Target target) : target_(target) {}
runtime::FunctionInfo
CodeGenTileLangWebGPU::AddFunction(const PrimFunc &f, bool skip_readonly_decl) {
  // clear previous generated state.
  this->InitFuncState(f);
  // reserve keywords
  name_supply_->ReserveName("var");
  name_supply_->ReserveName("let");
  name_supply_->ReserveName("const");
  // skip the first underscore, so SSA variable starts from
  name_supply_->FreshName("v_");
  // Setup the thread group info.
  ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
  ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
  ICHECK_EQ(name_supply_->FreshName("gridDim"), "gridDim");
  // add to alloc buffer type.
  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
  ICHECK(global_symbol.defined()) << "CodeGenTileLangWebGPU: Expect PrimFunc "
                                     "to have the global_symbol attribute";

  header_stream << "//----------------------------------------\n"
                << "// Function: " << global_symbol.value() << "\n"
                << "//----------------------------------------\n";
  runtime::FunctionInfo func_info;
  func_info.name = global_symbol.value();
  WebGPUWorkGroupInfo info = WebGPUWorkgroupInfoCollector::Collect(f->body);

  std::vector<Var> pod_args;
  int num_buffer = 0;
  // add param_access modes info to launch params
  std::ostringstream os_param_access;
  os_param_access << "paramWriteAccess:[";
  // setup buffer argumemts
  for (Var arg : f->params) {
    DataType t = arg.dtype();
    func_info.arg_types.push_back(t);
    if (t.is_handle()) {
      auto *ptr = arg->type_annotation.as<PointerTypeNode>();
      ICHECK(ptr) << "All handles passed to the CodeGenTileLangWebGPU must "
                     "have a type_annotation as a PointerType, "
                  << "and must point to a PrimType";
      auto *prim = ptr->element_type.as<PrimTypeNode>();
      ICHECK(prim) << "All handles passed to the CodeGenTileLangWebGPU must "
                      "have a type_annotation as a PointerType, "
                   << "and must point to a PrimType";
      DataType value_storage_type = prim->dtype;
      if (value_storage_type == DataType::Bool()) {
        // We need a physically addressable buffer type to support boolean
        // tensors. The loaded byte is cast to bool inside the LoadNode visitor
        // below.
        value_storage_type =
            boolean_storage_type_.with_lanes(value_storage_type.lanes());
      }
      std::string vid = AllocVarID(arg.get());
      std::string access_mode;
      if (num_buffer != 0) {
        os_param_access << ",";
      }
      if (skip_readonly_decl || info.write_access_set.count(arg)) {
        access_mode = "read_write";
        os_param_access << "1";
      } else {
        access_mode = "read";
        os_param_access << "0";
      }
      // add extra access mode info to launch params
      this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") "
                        << "var<storage, " << access_mode << "> " << vid
                        << " : array<";
      this->PrintType(value_storage_type, this->decl_stream);
      this->decl_stream << ">;\n";
    } else {
      pod_args.push_back(arg);
    }
  }
  // Store all pod arguments in a single buffer of int32
  // do bitcast to change to other data types
  // always pass gridDimX in to get around of the 65535 gridDim
  // restrictions in some platforms
  std::string type_pod_args = name_supply_->FreshName("PODArgs");
  std::string val_pod_args = name_supply_->FreshName("podArgs");
  std::string packGridDimX = name_supply_->FreshName("packGridDimX");
  this->decl_stream << "\nstruct " << type_pod_args << " {\n";
  for (size_t i = 0; i < pod_args.size(); ++i) {
    const Var &v = pod_args[i];
    ICHECK(!v.dtype().is_handle());
    std::string vid = AllocVarID(v.get());
    if (v.dtype() == DataType::Int(32)) {
      this->decl_stream << "  " << vid << ": i32";
    } else if (v.dtype() == DataType::UInt(32)) {
      this->decl_stream << "  " << vid << ": u32";
    } else if (v.dtype() == DataType::Float(32)) {
      this->decl_stream << "  " << vid << ": f32";
    } else {
      LOG(FATAL) << "Do not support pod argument type " << v.dtype();
    }
    this->decl_stream << ",\n";
    // value ref
    std::ostringstream vref;
    vref << val_pod_args << "." << vid;
    var_idmap_[v.get()] = vref.str();
  }
  this->decl_stream << "  " << packGridDimX << ": u32\n}\n";
  this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") "
                    << "var<uniform> " << val_pod_args << " : " << type_pod_args
                    << ";\n\n";
  // setup thread tags and param access in launch param tags;
  if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
    for (const auto &thread_tag : opt.value()) {
      func_info.launch_param_tags.push_back(thread_tag);
    }
  }
  os_param_access << "]";
  func_info.launch_param_tags.push_back(os_param_access.str());
  ICHECK(!info.has_block_index_z) << "blockIdx.z is not supported in WebGPU to "
                                     "accommodate large blockIdx.x";
  // annotate workgroup
  this->stream << "@compute @workgroup_size(" << info.workgroup_size[0] << ", "
               << info.workgroup_size[1] << ", " << info.workgroup_size[2]
               << ")\n";
  // add to alloc buffer type.
  // Function header.
  this->stream << "fn " << func_info.name << "(\n"
               << "  @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
               << "  @builtin(num_workgroups) gridDim : vec3<u32>,\n"
               << "  @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
               << ") {\n";
  // skip out of bound grids
  this->stream << "  if (blockIdx.z * gridDim.x + blockIdx.x > " // NOLINT(*)
               << val_pod_args << "." << packGridDimX << ") { return; }\n";
  // the function scope.
  int func_scope = this->BeginScope();
  this->PrintStmt(f->body);
  this->EndScope(func_scope);
  this->PrintIndent();
  this->stream << "}\n\n";
  return func_info;
}
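For a kernel with a single read-only f32 buffer, the paths above would assemble a WGSL prologue roughly like the following (illustrative output only; the kernel name, binding layout, and workgroup size are assumed values):

#include <iostream>

// Prints the kind of WGSL skeleton AddFunction emits; not taken from the repo.
int main() {
  std::cout
      << "@group(0) @binding(0) var<storage, read> A : array<f32>;\n"
      << "@group(0) @binding(1) var<uniform> podArgs : PODArgs;\n"
      << "@compute @workgroup_size(64, 1, 1)\n"
      << "fn main_kernel(\n"
      << "  @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
      << "  @builtin(num_workgroups) gridDim : vec3<u32>,\n"
      << "  @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
      << ") {\n"
      << "  if (blockIdx.z * gridDim.x + blockIdx.x > podArgs.packGridDimX)"
      << " { return; }\n"
      << "}\n";
}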
void CodeGenTileLangWebGPU::BindThreadIndex(const IterVar &iv) {
  ICHECK(!var_idmap_.count(iv->var.get()));
  std::ostringstream os;
  PrintType(iv->var.dtype(), os);
  if (iv->thread_tag == "blockIdx.x") {
    // WebGPU have restriction to limit the maximum size of blockId.x to be
    // 65535 We allow runtime to spread the load out to blockIdx.z so it can be
    // a large number.
    os << "(blockIdx.z * gridDim.x + blockIdx.x)";
    std::string tidx = os.str();
    std::string aggregated_bidx = SSAGetID(os.str(), iv->var.dtype());
    var_idmap_[iv->var.get()] = aggregated_bidx;
  } else {
    os << "(" << iv->thread_tag << ")";
    std::string tidx = os.str();
    this->MarkConst(tidx);
    var_idmap_[iv->var.get()] = tidx;
  }
}
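The flattening works because the runtime sizes gridDim.z so that z * gridDim.x + x covers every logical block. A small self-checking sketch of that arithmetic (the constants here are mine, chosen for illustration):

#include <cassert>
#include <cstdint>

// A logical 1-D grid of 200000 blocks is split across x and z because WebGPU
// caps each grid dimension at 65535.
int main() {
  const uint32_t kMaxDimX = 65535;
  const uint64_t logical_blocks = 200000;
  const uint32_t grid_x = kMaxDimX;
  const uint32_t grid_z =
      static_cast<uint32_t>((logical_blocks + grid_x - 1) / grid_x);
  // In the kernel, BindThreadIndex rebuilds the logical id as
  // blockIdx.z * gridDim.x + blockIdx.x.
  const uint64_t last_logical =
      uint64_t(grid_z - 1) * grid_x + (logical_blocks - 1) % grid_x;
  assert(last_logical == logical_blocks - 1);
  (void)last_logical;
}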
void CodeGenTileLangWebGPU::PrintType(DataType t,
                                      std::ostream &os) { // NOLINT(*)
  int lanes = t.lanes();
  if (t.is_handle()) {
    LOG(FATAL) << "Cannot print handle type in WebGPU";
  }
  if (t.is_void()) {
    os << "void";
    return;
  }
  if (t == DataType::Bool()) {
    os << "bool";
    return;
  }
  if (lanes != 1) {
    // ICHECK(lanes >= 2 && lanes <= 4) << "CodeGenTileLangWebGPU: only allows
    // vector with lanes in {2, 3, 4} " << " while lanes is " << lanes;
    os << "vec" << lanes << "<";
  }
  if (t.is_float()) {
    ICHECK(t.bits() == 16 || t.bits() == 32)
        << "CodeGenTileLangWebGPU: only support f16 or f32";
    if (t.bits() == 16) {
      // Using f16 requires enable directive
      enable_fp16_ = true;
    }
    os << "f" << t.bits();
  } else if (t.is_uint()) {
    ICHECK(t.bits() != 64) << "CodeGenTileLangWebGPU: do not support u64";
    os << "u" << t.bits();
  } else if (t.is_int()) {
    ICHECK(t.bits() != 64) << "CodeGenTileLangWebGPU: do not support i64";
    os << "i" << t.bits();
  } else {
    LOG(FATAL) << "CodeGenTileLangWebGPU: Cannot convert type " << t
               << " to WebGPU type";
  }
  if (lanes != 1) {
    os << ">";
  }
}
void CodeGenTileLangWebGPU::PrintStorageSync(const CallNode *op) {
  const std::string &sync = op->args[0].as<StringImmNode>()->value;
  if (sync == "warp") {
    this->PrintIndent();
    this->stream << "workgroupBarrier();\n";
  } else if (sync == "shared") {
    this->PrintIndent();
    this->stream << "workgroupBarrier();\n";
  } else if (sync == "global") {
    LOG(FATAL) << "global barrier not supported";
  }
}

void CodeGenTileLangWebGPU::PrintSSAAssign(const std::string &target,
                                           const std::string &src,
                                           DataType type) {
  stream << "let " << target << " : ";
  PrintType(type, stream);
  stream << " = " << src << ";\n";
}
void CodeGenTileLangWebGPU::VisitExpr_(const BroadcastNode *op,
                                       std::ostream &os) { // NOLINT(*)
  std::string v = PrintExpr(op->value);
  int lanes = op->dtype.lanes();
  PrintType(op->dtype, os);
  os << "(";
  for (int i = 0; i < lanes; ++i) {
    if (i != 0) os << ", ";
    os << v;
  }
  os << ')';
}

PrimExpr CodeGenTileLangWebGPU::EnforceU32(PrimExpr value) {
  return cast(DataType::UInt(32, value.dtype().lanes()), value);
}
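WGSL rejects mismatched shift-amount types, which is why the shift visitors below route the second operand through EnforceU32 before printing. A minimal sketch of the resulting emission (the helper name is mine; the u32(...) cast mirrors what printing the cast expression produces for a 32-bit operand):

#include <iostream>
#include <string>

// Hypothetical helper showing the shape of the emitted WGSL expression.
std::string EmitShiftLeft(const std::string &lhs, const std::string &rhs) {
  return "(" + lhs + " << u32(" + rhs + "))";
}

int main() {
  std::cout << EmitShiftLeft("x", "k") << "\n";  // (x << u32(k))
}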
void CodeGenTileLangWebGPU::VisitExpr_(const CallNode *op,
                                       std::ostream &os) { // NOLINT(*)
  if (op->op.same_as(builtin::reinterpret())) {
    // generate bitcast<TYPE>(ARG)
    os << "bitcast<";
    this->PrintType(op->dtype, os);
    os << ">(";
    this->PrintExpr(op->args[0], os);
    os << ")";
  } else if (op->op.same_as(builtin::shift_right())) {
    os << '(';
    this->PrintExpr(op->args[0], os);
    os << ">>";
    // WebGPU requires shift bits to be u32.
    this->PrintExpr(EnforceU32(op->args[1]), os);
    os << ')';
  } else if (op->op.same_as(builtin::shift_left())) {
    os << '(';
    this->PrintExpr(op->args[0], os);
    os << "<<";
    // WebGPU requires shift bits to be u32.
    this->PrintExpr(EnforceU32(op->args[1]), os);
    os << ')';
  } else if (op->op.same_as(builtin::if_then_else())) {
    // conditional that skips eval if cond evals to false
    std::string result = name_supply_->FreshName("condval");
    std::string cond = PrintExpr(op->args[0]);
    this->PrintIndent();
    this->stream << "var " << result << " : ";
    PrintType(op->dtype, this->stream);
    this->stream << ";\n";
    this->PrintIndent();
    this->stream << "if (" << cond << ") {\n";
    {
      int then_scope = this->BeginScope();
      std::string true_val = PrintExpr(op->args[1]);
      this->PrintIndent();
      this->stream << result << " = " << true_val << ";\n} else {\n";
      this->EndScope(then_scope);
    }
    {
      int else_scope = this->BeginScope();
      std::string false_val = PrintExpr(op->args[2]);
      this->PrintIndent();
      this->stream << result << " = " << false_val << ";\n}\n";
      this->EndScope(else_scope);
    }
    os << result;
  } else {
    CodeGenC::VisitExpr_(op, os);
  }
}
void CodeGenTileLangWebGPU::VisitExpr_(const CastNode *op,
                                       std::ostream &os) { // NOLINT(*)
  PrintType(op->dtype, os);
  os << "(" << PrintExpr(op->value) << ")";
}

void CodeGenTileLangWebGPU::VisitExpr_(const SelectNode *op,
                                       std::ostream &os) { // NOLINT(*)
  os << "select(" << PrintExpr(op->false_value) << ", "
     << PrintExpr(op->true_value) << ", " << PrintExpr(op->condition) << ")";
}

void CodeGenTileLangWebGPU::VisitExpr_(const IntImmNode *op,
                                       std::ostream &os) { // NOLINT(*)
  if (op->dtype.bits() == 32) {
    std::ostringstream temp;
    if (op->dtype.is_int()) {
      temp << op->value << "i";
    } else {
      ICHECK(op->dtype.is_uint());
      temp << op->value << "u";
    }
    this->MarkConst(temp.str());
    os << temp.str();
  } else {
    this->PrintType(op->dtype, os);
    os << "(" << op->value << ")";
  }
}

void CodeGenTileLangWebGPU::VisitExpr_(const FloatImmNode *op,
                                       std::ostream &os) { // NOLINT(*)
  std::ostringstream temp;
  temp << std::scientific << op->value;
  if (op->dtype.bits() == 32) {
    temp << 'f';
  } else if (op->dtype.bits() == 16) {
    // Using f16 requires enable directive
    enable_fp16_ = true;
    temp << 'h';
  } else {
    LOG(FATAL) << "Unsupported floating point bits " << op->dtype.bits();
  }
  MarkConst(temp.str());
  os << temp.str();
}
void CodeGenTileLangWebGPU::VisitExpr_(const BufferLoadNode *op,
                                       std::ostream &os) { // NOLINT(*)
  // NOTE: direct impl of load/store for correctness
  // Each printing stmt must stand on their own after all preprocessing steps
  // to ensure correctness in the case of nested-expression
  // do not try to lift common printings from each case
  ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported.";
  DataType value_dtype = op->dtype;
  PrimExpr index = op->indices[0];
  Var buffer_var = op->buffer->data;
  DataType element_dtype = op->buffer->dtype;
  int lanes = op->dtype.lanes();
  std::string buffer_vid = GetVarID(buffer_var.get());

  if (value_dtype.lanes() == element_dtype.lanes()) {
    // Direct buffer loading
    // Special handle bool loading
    if (value_dtype == DataType::Bool()) {
      this->PrintType(value_dtype, os);
      os << "(";
    } else {
      ICHECK(value_dtype == element_dtype);
    }
    ICHECK_EQ(index.dtype().lanes(), 1);
    os << buffer_vid << "[" << this->PrintExpr(index) << "]";
    // Special handle bool loading
    if (value_dtype == DataType::Bool()) {
      os << ")";
    }
  } else {
    // Vector load from scalar buffer
    ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array";
    ICHECK(value_dtype.element_of() == element_dtype)
        << "WebGPU vector loading requires base type to match";
    arith::PVar<PrimExpr> base;
    if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) {
      // vec3<f32>(buf[base + 0], buf[base + 1], buf[base + 2]);
      std::string base_vid = SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype());
      PrintType(element_dtype.with_lanes(value_dtype.lanes()), os);
      os << "(";
      for (int i = 0; i < lanes; ++i) {
        if (i != 0) os << ", ";
        os << buffer_vid << "[" << base_vid << " + " << i << "]";
      }
      os << ")";
    } else {
      // vec3<f32>(buf[index[0]], buf[index[1]], buf[index[2]]);
      std::string index_vid = SSAGetID(PrintExpr(index), index.dtype());
      PrintType(element_dtype.with_lanes(value_dtype.lanes()), os);
      os << "(";
      for (int i = 0; i < lanes; ++i) {
        if (i != 0) os << ", ";
        os << buffer_vid << "[" << index_vid << "[" << i << "]]";
      }
      os << ")";
    }
  }
}
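The bool handling above exists because WGSL storage buffers need a concrete integer element type; boolean_storage_type_ supplies it, and the load/store visitors convert at the buffer boundary. A minimal sketch of that round-trip (the helper names and the i32 element type here are assumptions):

#include <iostream>
#include <string>

// Bool tensors are backed by an integer storage buffer, so a load wraps the
// element in bool(...) and a store wraps the value in the element type.
std::string EmitBoolLoad(const std::string &buf, const std::string &idx) {
  return "bool(" + buf + "[" + idx + "])";
}
std::string EmitBoolStore(const std::string &buf, const std::string &idx,
                          const std::string &elem_ty, const std::string &val) {
  return buf + "[" + idx + "] = " + elem_ty + "(" + val + ");";
}

int main() {
  std::cout << EmitBoolLoad("mask", "i") << "\n"
            << EmitBoolStore("mask", "i", "i32", "pred") << "\n";
}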
void CodeGenTileLangWebGPU::VisitStmt_(const LetStmtNode *op) {
  // use ssa form.
  if (print_ssa_form_) {
    std::string value = PrintExpr(op->value);
    ICHECK(!var_idmap_.count(op->var.get()));
    var_idmap_[op->var.get()] = value;
  } else {
    PrintIndent();
    std::string value = PrintExpr(op->value);
    this->stream << "let " << AllocVarID(op->var.get()) << " : ";
    PrintType(op->var.dtype(), this->stream);
    this->stream << " = " << value << ";\n";
  }
  PrintStmt(op->body);
}
void CodeGenTileLangWebGPU::VisitStmt_(const BufferStoreNode *op) {
  CHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported.";
  DataType value_dtype = op->value.dtype();
  DataType element_dtype = op->buffer->dtype;
  PrimExpr index = op->indices[0];
  Var buffer_var = op->buffer->data;
  std::string buffer_vid = GetVarID(buffer_var.get());

  if (value_dtype.lanes() == element_dtype.lanes()) {
    // must execute print expr first
    // so we won't have recursive append to stream
    std::string index_vid = PrintExpr(index);
    std::string value_vid = PrintExpr(op->value);
    // now print the assignment line.
    this->PrintIndent();
    stream << buffer_vid << "[" << index_vid << "] = ";
    // special explicit conversion of bool
    if (value_dtype == DataType::Bool()) {
      PrintType(element_dtype, stream);
      stream << "(";
    } else {
      ICHECK(value_dtype == element_dtype);
    }
    stream << value_vid;
    // Special handle bool store
    if (value_dtype == DataType::Bool()) {
      stream << ")";
    }
    stream << ";\n";
  } else {
    // Vector store into scalar buffer
    ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array";
    ICHECK(value_dtype.element_of() == element_dtype)
        << "WebGPU vector store requires base type to match";
    std::string value_vid = PrintExpr(op->value);
    arith::PVar<PrimExpr> base;
    if (arith::ramp(base, 1, value_dtype.lanes()).Match(index)) {
      // buf[base + 0] = value[0]
      // buf[base + 1] = value[1]
      std::string base_vid = SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype());
      for (int i = 0; i < value_dtype.lanes(); ++i) {
        this->PrintIndent();
        stream << buffer_vid << "[" << base_vid << " + " << i << "] = "
               << value_vid << "[" << i << "];\n";
      }
    } else {
      // buf[index[0]] = value[0]
      // buf[index[1]] = value[1]
      std::string index_vid = SSAGetID(PrintExpr(index), index.dtype());
      for (int i = 0; i < value_dtype.lanes(); ++i) {
        this->PrintIndent();
        stream << buffer_vid << "[" << index_vid << "[" << i << "]] = "
               << value_vid << "[" << i << "];\n";
      }
    }
  }
}
void CodeGenTileLangWebGPU::VisitStmt_(const AllocateNode *op) {
  ICHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());
  size_t constant_size = op->ConstantAllocationSize();
  ICHECK_GT(constant_size, 0)
      << "Can only handle constant size stack allocation for now";

  auto storage_scope =
      runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
  if (storage_scope.rank == runtime::StorageRank::kShared) {
    this->decl_stream << "var<workgroup> " << vid << " : array<";
    PrintType(op->dtype, this->decl_stream);
    this->decl_stream << ", " << constant_size << ">;\n";
  } else if (storage_scope.rank == runtime::StorageRank::kLocal) {
    // TODO(Charlie): These code would cause non-uniformity as it introduces
    // variables in module scope rather than function scope; but it was included
    // for some unknown reasons; kept for now. this->decl_stream <<
    // "var<private> " << vid << " : array<"; PrintType(op->dtype,
    // this->decl_stream); this->decl_stream << ", " << constant_size << ">;\n";
    this->PrintIndent();
    this->stream << "var " << vid << " : array<";
    PrintType(op->dtype, this->stream);
    this->stream << ", " << constant_size << ">;\n";
  } else {
    LOG(FATAL) << "WebGPU: Do not support storage scope: "
               << storage_scope.to_string();
  }
  this->PrintStmt(op->body);
}
void CodeGenTileLangWebGPU::VisitStmt_(const ForNode *op) {
  std::string extent = PrintExpr(op->extent);
  std::string vid = AllocVarID(op->loop_var.get());
  ICHECK(is_zero(op->min));
  PrintIndent();
  stream << "for (var " << vid << " : ";
  PrintType(op->loop_var.dtype(), stream);
  stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n";
  int for_scope = BeginScope();
  PrintStmt(op->body);
  this->EndScope(for_scope);
  PrintIndent();
  stream << "}\n";
}

void CodeGenTileLangWebGPU::VisitStmt_(const AssertStmtNode *op) {
  // skip assert
  PrintStmt(op->body);
}

void CodeGenTileLangWebGPU::VisitStmt_(const AllocateConstNode *op) {
  LOG(FATAL) << "WebGPU: do not support alloc const";
}

void CodeGenTileLangWebGPU::VisitStmt_(const WhileNode *op) {
  PrintIndent();
  stream << "while (true) {\n";
  int while_scope = BeginScope();
  std::string cond = PrintExpr(op->condition);
  PrintIndent();
  stream << "if (!(" << cond << ")) { break; }\n";
  PrintStmt(op->body);
  this->EndScope(while_scope);
  PrintIndent();
  stream << "}\n";
}
//-------------------------------------------------
// WebGPUSourceModule to enable export
//-------------------------------------------------
class WebGPUSourceModuleNode final : public runtime::ModuleNode {
public:
  explicit WebGPUSourceModuleNode(
      std::unordered_map<std::string, std::string> smap,
      std::unordered_map<std::string, runtime::FunctionInfo> fmap)
      : smap_(smap), fmap_(fmap) {}

  const char *type_key() const final { return "webgpu"; }
  /*! \brief Get the property of the runtime module .*/
  int GetPropertyMask() const final {
    return runtime::ModulePropertyMask::kBinarySerializable;
  }

  ffi::Function GetFunction(const String &name,
                            const ObjectPtr<Object> &sptr_to_self) final {
    LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run "
                  "through tvmjs";
    return ffi::Function(nullptr);
  }

  void SaveToBinary(dmlc::Stream *stream) final {
    stream->Write(fmap_);
    stream->Write(smap_);
  }

  String GetSource(const String &format) final {
    if (format == "func_info") {
      std::ostringstream stream;
      dmlc::JSONWriter(&stream).Write(fmap_);
      return stream.str();
    } else {
      std::ostringstream os;
      for (const auto &kv : smap_) {
        os << kv.second;
      }
      return os.str();
    }
  }

private:
  // function shader code table.
  std::unordered_map<std::string, std::string> smap_;
  // function information table.
  std::unordered_map<std::string, runtime::FunctionInfo> fmap_;
};
//-------------------------------------------------
// Build logic.
//-------------------------------------------------
runtime::Module BuildTileLangWebGPU(IRModule mod, Target target) {
  mod = tir::transform::PointerValueTypeRewrite()(std::move(mod));
  bool output_ssa = false;
  bool skip_readonly_decl = false;
  std::unordered_map<std::string, std::string> smap;
  std::unordered_map<std::string, runtime::FunctionInfo> fmap;
  // narrow all i64 to i32
  mod = tir::transform::ForceNarrowIndexToInt32()(std::move(mod));
  for (auto kv : mod->functions) {
    CodeGenTileLangWebGPU cg(target);
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangWebGPU: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
        << "CodeGenTileLangWebGPU: expect calling_conv equals "
           "CallingConv::kDeviceKernelLaunch";
    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
    ICHECK(global_symbol.defined()) << "CodeGenTileLangWebGPU: Expect PrimFunc "
                                       "to have the global_symbol attribute";
    std::string f_name = global_symbol.value();
    cg.Init(output_ssa);
    fmap[f_name] = cg.AddFunction(f, skip_readonly_decl);
    std::string code = cg.Finish();
    smap[f_name] = code;
  }
  auto n = make_object<WebGPUSourceModuleNode>(smap, fmap);
  return runtime::Module(n);
}

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("target.build.tilelang_webgpu",
                        [](IRModule mod, Target target) {
                          return BuildTileLangWebGPU(mod, target);
                        });
});

} // namespace codegen
} // namespace tvm
src/target/codegen_webgpu.h deleted 100644 → 0 View file @ 8f4628e0
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_webgpu.h
* \brief Generate WebGPU shaders in WGSL.
*
* This module generates WGSL shading language.
* See https://www.w3.org/TR/WGSL/ for the language reference.
*/
#ifndef TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
#define TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
#include <tvm/target/codegen.h>
#include <string>
#include "target/source/codegen_c.h"
namespace tvm {
namespace codegen {
/*!
* \brief WebGPU code generator.
*
* Note WGSL have a different syntax from normal C.
* We only leverage the C for expression generation and
* write most of the language generations.
*/
class CodeGenTileLangWebGPU final : public CodeGenC {
public:
  explicit CodeGenTileLangWebGPU(Target target);
  // overrides
  std::string Finish() final;
  using CodeGenC::AddFunction;
  runtime::FunctionInfo AddFunction(const PrimFunc &f,
                                    bool skip_readonly_decl); // NOLINT(*)
  void InitFuncState(const PrimFunc &f) final;
  void PrintStorageSync(const CallNode *op) final;     // NOLINT(*)
  void PrintType(DataType t, std::ostream &os) final;  // NOLINT(*)
  void BindThreadIndex(const IterVar &iv) final;       // NOLINT(*)
  // assignment printing
  void PrintSSAAssign(const std::string &target, const std::string &src,
                      DataType type) final;
  // overload visitor
  void VisitExpr_(const BroadcastNode *op, std::ostream &os) final;  // NOLINT(*)
  void VisitExpr_(const CallNode *op, std::ostream &os) final;       // NOLINT(*)
  void VisitExpr_(const BufferLoadNode *op, std::ostream &os) final; // NOLINT(*)
  void VisitExpr_(const CastNode *op, std::ostream &os) final;       // NOLINT(*)
  void VisitExpr_(const SelectNode *op, std::ostream &os) override;  // NOLINT(*)
  void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;   // NOLINT(*)
  void VisitExpr_(const IntImmNode *op, std::ostream &os) final;     // NOLINT(*)
  // stmt printing
  void VisitStmt_(const LetStmtNode *op) final;
  void VisitStmt_(const BufferStoreNode *op) final;
  void VisitStmt_(const ForNode *op) final;
  void VisitStmt_(const AllocateNode *op) final;
  void VisitStmt_(const AssertStmtNode *op) final;
  void VisitStmt_(const AllocateConstNode *op) final;
  void VisitStmt_(const WhileNode *op) final;

private:
  /*!
   * \brief Enforce value to be U32.
   */
  static PrimExpr EnforceU32(PrimExpr value);
  /*!
   * \brief Storage type of bool values.
   */
  DataType boolean_storage_type_{DataType::Int(8)};
  // whether enable fp16
  bool enable_fp16_{false};
  /*! \brief the header stream for function label and enable directive if any,
   * goes before any other declaration */
  std::ostringstream header_stream;
  Target target_;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
src/target/intrin_rule_cuda.cc View file @ bbbf4207
...
...
@@ -5,6 +5,7 @@
#include <tvm/tir/builtin.h>
#include <tvm/tir/op_attr_types.h>
#include "../support/ffi_aliases.h"
#include "target/intrin_rule.h"
namespace tvm {
...
...
src/target/intrin_rule_hip.cc View file @ bbbf4207
...
...
@@ -5,6 +5,7 @@
#include <tvm/tir/builtin.h>
#include <tvm/tir/op_attr_types.h>
#include "../support/ffi_aliases.h"
#include "target/intrin_rule.h"
namespace tvm {
...
...
src/target/ptx.cc View file @ bbbf4207
...
...
@@ -74,9 +74,9 @@ DataType DTypeFromString(const std::string str) {
    return DataType::kInt64;
  } else if (str == "uint64" || str == ".u64") {
    return DataType::kUInt64;
-  } else if (str == "e4m3" || str == ".e4m3") {
+  } else if (str == "float8_e4m3" || str == "e4m3" || str == ".e4m3") {
    return DataType::kFloat8_e4m3;
-  } else if (str == "e5m2" || str == ".e5m2") {
+  } else if (str == "float8_e5m2" || str == "e5m2" || str == ".e5m2") {
    return DataType::kFloat8_e5m2;
  } else if (str == "float16" || str == "fp16" || str == ".f16") {
    return DataType::kFloat16;
...
...
@@ -1529,5 +1529,20 @@ std::string PrintWaitBarrierAsm(const std::string &barrier) {
  return predicated_asm_code;
}

+std::string GetMMARegisterType(const ptx::DataType &dtype) {
+  switch (dtype) {
+  case ptx::DataType::kInt32:
+    return "unsigned";
+  case ptx::DataType::kUInt32:
+    return "unsigned";
+  case ptx::DataType::kFloat32:
+    return "float";
+  case ptx::DataType::kFloat64:
+    return "double";
+  default:
+    return "unsigned";
+  }
+}

} // namespace codegen
} // namespace tvm::tl
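A standalone mirror of the new helper's mapping, usable to sanity-check the behavior outside the codegen (the enum here is a stand-in for ptx::DataType): MMA fragments are held in 32-bit registers, so everything that is not f32/f64 falls through to "unsigned".

#include <cassert>
#include <string>

enum class PtxDType { kInt32, kUInt32, kFloat32, kFloat64, kFloat16 };

std::string MMARegisterType(PtxDType dtype) {  // mirror, not the real helper
  switch (dtype) {
  case PtxDType::kFloat32: return "float";
  case PtxDType::kFloat64: return "double";
  default:                 return "unsigned";
  }
}

int main() {
  assert(MMARegisterType(PtxDType::kFloat16) == "unsigned");
  assert(MMARegisterType(PtxDType::kFloat32) == "float");
}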
src/target/ptx.h View file @ bbbf4207
...
...
@@ -269,6 +269,11 @@ std::string PrintArriveBarrierExpectTxAsm(const std::string &barrier,
 */
std::string PrintWaitBarrierAsm(const std::string &barrier);

+/*!
+ * \brief Return the register-level C++ type used by MMA fragments.
+ */
+std::string GetMMARegisterType(const ptx::DataType &dtype);

} // namespace codegen
} // namespace tvm::tl
...
...
src/target/rt_mod_cpp.cc View file @ bbbf4207
#include "codegen_cpp.h"
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/reflection/registry.h>
#include "../support/ffi_aliases.h"
namespace tvm {
namespace codegen {
-runtime::Module BuildCPPHost(IRModule mod, Target target) {
+ffi::Module BuildCPPHost(IRModule mod, Target target) {
  bool output_ssa = false;
  bool emit_asserts = false;
  bool emit_fwd_func_decl = true;
...
...
@@ -67,10 +70,10 @@ runtime::Module BuildCPPHost(IRModule mod, Target target) {
  return CSourceModuleCreate(code, "c", cg.GetFunctionNames());
}

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("target.build.tilelang_cpp", BuildCPPHost);
-});
+}

} // namespace codegen
} // namespace tvm
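The TVM_FFI_STATIC_INIT_BLOCK spelling changes here from a braced macro argument to a function-definition style. A simplified sketch of how a macro can support the new form (an illustration only, not the actual TVM macro):

#include <iostream>

// The macro expands to a function header, so the braces that follow form an
// ordinary function body, and a dynamic initializer runs it once at load time.
#define STATIC_INIT_BLOCK()                                                   \
  static void StaticInitBody();                                               \
  static const int static_init_token = (StaticInitBody(), 0);                 \
  static void StaticInitBody()

STATIC_INIT_BLOCK() { std::cout << "target.build.tilelang_cpp registered\n"; }

int main() {}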
src/target/rt_mod_cuda.cc View file @ bbbf4207
...
...
@@ -26,18 +26,19 @@ ExtractFuncInfo(const IRModule &mod) {
    }
    info.arg_types.push_back(f->params[i].dtype());
  }
-  if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
+  if (auto opt =
+          f->GetAttr<ffi::Array<ffi::String>>(tir::attr::kKernelLaunchParams)) {
    for (const auto &tag : opt.value()) {
      info.launch_param_tags.push_back(tag);
    }
  }
-  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+  auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
  fmap[static_cast<std::string>(global_symbol.value())] = info;
  }
  return fmap;
}

-runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
+ffi::Module BuildTileLangCUDA(IRModule mod, Target target) {
  bool output_ssa = false;
  CodeGenTileLangCUDA cg;
  cg.Init(output_ssa);
...
...
@@ -70,7 +71,7 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
  return runtime::CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}

-runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
+ffi::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
  bool output_ssa = false;
  CodeGenTileLangCUDA cg;
  cg.Init(output_ssa);
...
...
@@ -93,13 +94,13 @@ runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
  return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod), code);
}

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("target.build.tilelang_cuda", BuildTileLangCUDA)
      .def("target.build.tilelang_cuda_without_compile",
           BuildTileLangCUDAWithoutCompile);
-});
+}

} // namespace codegen
} // namespace tvm