OpenDAS / tilelang / Commits

Commit 2c490782
Authored Oct 28, 2025 by Lukinon; committed by qisan, Oct 28, 2025

[Feature] Add support for Hygon DCU

parent 7d389a43

Showing 20 changed files with 1648 additions and 10 deletions (+1648, -10)
examples/gemm/example_gemm_intrinsics_dcu.py        +190  -0
examples/minference/ops/vertical_slash_index.hip    +161  -0
src/layout/gemm_layouts.cc                          +18   -0
src/layout/layout.h                                 +3    -0
src/op/gemm.cc                                      +8    -1
src/target/codegen_hip.cc                           +72   -6
src/target/intrin_rule_hip.cc                       +11   -1
src/target/utils.cc                                 +11   -0
src/target/utils.h                                  +1    -0
src/tl_templates/dcu_hip/common.h                   +146  -0
src/tl_templates/dcu_hip/copy.h                     +111  -0
src/tl_templates/dcu_hip/core.hpp                   +106  -0
src/tl_templates/dcu_hip/debug.h                    +192  -0
src/tl_templates/dcu_hip/gemm.h                     +324  -0
src/tl_templates/dcu_hip/hip_fp8.h                  +75   -0
src/tl_templates/dcu_hip/ldsm.h                     +3    -0
src/tl_templates/dcu_hip/reduce.h                   +167  -0
src/tl_templates/dcu_hip/threadblock_swizzle.h      +46   -0
tilelang/contrib/hipcc.py                           +2    -1
tilelang/contrib/rocm.py                            +1    -1
examples/gemm/example_gemm_intrinsics_dcu.py (new file, mode 100644)
from tilelang import tvm as tvm
from tvm import DataType
import tilelang
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mmac_macro_generator import (MatrixCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
from tilelang import disable_cache

disable_cache()
def make_swizzle_layout(shared_buf):
    dtype = shared_buf.dtype
    shape = shared_buf.shape
    can_swizzle = shape[-1] * DataType(dtype).bits == 512
    if not can_swizzle:
        return T.Layout(shape, lambda *args: args)

    def transform_func(i, j):
        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
        return [new_warp_i, new_warp_j]

    return T.Layout(shape, transform_func)
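As a quick sanity check of the 512-bit condition above, a minimal sketch in plain Python (values chosen for illustration only): a buffer whose last dimension packs exactly one 512-bit row qualifies for swizzling; anything else falls back to the identity layout.

# Hedged illustration of the can_swizzle predicate in make_swizzle_layout.
bits = {"float16": 16, "int8": 8}
for dtype, last_dim in [("float16", 32), ("float16", 64), ("int8", 64)]:
    can_swizzle = last_dim * bits[dtype] == 512
    print(dtype, last_dim, "->", "swizzle" if can_swizzle else "identity")
# float16/32 -> swizzle, float16/64 -> identity, int8/64 -> swizzle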
@tilelang.jit(out_idx=[2])
@simplify_prim_func
def tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
):
    assert in_dtype in [
        "float16",
        "int8",
    ], "Currently only float16 and int8 are supported"
    assert out_dtype in [
        "float16",
        "float32",
        "int32",
    ], "Currently only float16, float32 and int32 are supported"

    micro_size_x = micro_size_y = micro_size_k = 16
    if out_dtype == "int32":
        micro_size_k = 32

    # This is a debug config
    block_row_warps = 2
    block_col_warps = 2
    warp_row_tiles = 64
    warp_col_tiles = 64
    # chunk = 32 if in_dtype == "float16" else 64
    chunk = 32
    shared_scope = "shared.dyn"

    # Pipeline Stage
    stage = 2

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

    A_shape = (M, K)
    B_shape = (N, K)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (
        block_M // micro_size_x,
        block_N // micro_size_y,
        micro_size_x,
        micro_size_y,
    )

    warp_size = 64
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size_a = (micro_size_x * micro_size_k) // warp_size
    local_size_b = (micro_size_y * micro_size_k) // warp_size
    local_size_c = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y

    # MMAC Wrapper to Auto Generate Code for MMAC
    mmac_emitter = MatrixCoreIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=False,
        b_transposed=True,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
    )
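A quick worked check of the derived sizes for this debug config (plain Python, mirroring the expressions above; not part of the file):

# 2x2 waves of 64 lanes, 64x64 tiles per wave, 16x16x16 micro tiles
warp_size, block_row_warps, block_col_warps = 64, 2, 2
warp_row_tiles = warp_col_tiles = 64
micro = 16
threads = warp_size * block_row_warps * block_col_warps   # 256
block_M = block_row_warps * warp_row_tiles                # 128
block_N = block_col_warps * warp_col_tiles                # 128
warp_rows = warp_row_tiles // micro                       # 4 micro tiles per wave
local_size_a = (micro * micro) // warp_size               # 4 elements per lane
print(threads, block_M, block_N, warp_rows, local_size_a)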
    @T.prim_func
    def gemm_intrinsics(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
            })

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=stage):
                # Load A into shared memory
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]

                # Load B into shared memory
                for j, k in T.Parallel(block_N, block_K):
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Load A into fragment
                    mmac_emitter.ldmatrix_a(A_local, A_shared, ki)

                    # Load B into fragment
                    mmac_emitter.ldmatrix_b(B_local, B_shared, ki)

                    # Perform Matrix Multiplication
                    mmac_emitter.mmac(A_local, B_local, C_local)

            # Perform STMatrix
            mmac_emitter.stmatrix(C_local, C_shared)

            # Store shared into global
            for i, j in T.Parallel(block_M, block_N):
                C[by * block_M + i, bx * block_N + j] = C_shared[
                    j // micro_size_y,
                    i // micro_size_x,
                    i % micro_size_x,
                    j % micro_size_y,
                ]

    return gemm_intrinsics
def ref_program(A, B):
    return A @ B.T


def main():
    M, N, K = 16384, 16384, 16384
    in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32"
    kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
    src_code = kernel.get_kernel_source()
    # src_code is the generated cuda source
    assert src_code is not None
    profiler = kernel.get_profiler()
    latency = profiler.do_bench(profiler.func, warmup=25)
    print(latency)
    print(kernel.get_kernel_source())
    # Ensure that the latency is not None
    assert latency is not None
    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
    main()
examples/minference/ops/vertical_slash_index.hip (new file, mode 100644)
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <assert.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <torch/extension.h>
#include <hip/hip_runtime.h>
__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) {
for (int idx = range_start; idx < range_end; idx += block_size) {
block_offset[block_count++] = idx;
}
}
__global__ void convert_vertical_slash_indexes_kernel(
const int* seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int N_HEADS,
int N_ROWS,
int BLOCK_SIZE_M,
int BLOCK_SIZE_N,
int NNZ_V,
int NNZ_S
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int seqlen = seqlens[batch_idx];
int block_idx_m = group_idx * blockDim.x + threadIdx.x;
int start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= seqlen) {
return;
}
int end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
int tmp_col_cnt = 0, tmp_blk_cnt = 0;
int s = 0, v = 0;
int v_idx = vertical_indexes[v++];
int s_idx = slash_indexes[s++];
while (s_idx >= end_m) {
s_idx = slash_indexes[s++];
}
s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
v_idx = end_m + BLOCK_SIZE_M;
}
} else {
if (s < NNZ_S) {
s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
break;
}
if (s_idx > range_end + BLOCK_SIZE_M) {
save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
void convert_vertical_slash_indexes_64x64(
const int* seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int BATCH_SIZE,
int N_HEADS,
int N_ROWS,
int NNZ_V,
int NNZ_S
) {
const int BLOCK_SIZE_M = 64;
const int BLOCK_SIZE_N = 64;
const int N_THREADS = 64;
const dim3 dimBlock(N_THREADS);
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
hipLaunchKernelGGL(( convert_vertical_slash_indexes_kernel), dim3(dimGrid), dim3(dimBlock), 0, 0,
seqlens, vertical_indexes, slash_indexes,
block_count, block_offset, column_count, column_index,
N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S
);
}
std::vector<at::Tensor> convert_vertical_slash_indexes(
torch::Tensor seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int context_size,
int block_size_M,
int block_size_N
) {
assert(block_size_M == 64);
assert(block_size_N == 64);
hipSetDevice(seqlens.get_device());
int batch_size = slash_indexes.size(0);
int num_heads = slash_indexes.size(1);
int nnz_slash = slash_indexes.size(2);
int nnz_vertical = vertical_indexes.size(2);
int num_rows = (context_size + block_size_M - 1) / block_size_M;
torch::Tensor block_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
torch::Tensor block_offset = torch::zeros({batch_size, num_heads, num_rows, nnz_slash}, seqlens.options());
torch::Tensor column_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
torch::Tensor column_index = torch::zeros({batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options());
convert_vertical_slash_indexes_64x64(
seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(),
slash_indexes.data_ptr<int>(),
block_count.data_ptr<int>(),
block_offset.data_ptr<int>(),
column_count.data_ptr<int>(),
column_index.data_ptr<int>(),
batch_size,
num_heads,
num_rows,
nnz_vertical,
nnz_slash
);
return { block_count, block_offset, column_count, column_index };
}
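The host entry point above returns the four index tensors in one call. Below is a hedged loading sketch in plain Python: it assumes a ROCm/DTK build of PyTorch whose torch.utils.cpp_extension.load helper can compile the .hip source, and it assumes a pybind11 binding named convert_vertical_slash_indexes is registered elsewhere (the module definition is not part of this excerpt); the module name and tensor shapes are illustrative only.

# Hypothetical loader; module name and binding registration are assumptions.
import torch
from torch.utils.cpp_extension import load

ops = load(
    name="vertical_slash_ops",  # illustrative name, not from this commit
    sources=["examples/minference/ops/vertical_slash_index.hip"],
    verbose=True,
)
seqlens = torch.tensor([4096], dtype=torch.int32, device="cuda")
vertical = torch.arange(64, dtype=torch.int32, device="cuda").reshape(1, 1, 64)
slash = torch.arange(64, dtype=torch.int32, device="cuda").reshape(1, 1, 64)
block_count, block_offset, column_count, column_index = \
    ops.convert_vertical_slash_indexes(seqlens, vertical, slash, 4096, 64, 64)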
src/layout/gemm_layouts.cc
@@ -156,6 +156,23 @@ Fragment makeGemmSparseFragmentC(const int block_m, const int block_n,
  return block_layout;
}

Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
                              const int warp_m, const int warp_n,
                              const int element_size) {
  if (element_size == 64)
    LOG(FATAL) << "Not supported";
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
  ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
  auto base_layout = makeGemmFragmentC16x16CDNA()->Repeat({1, 1}, false);
  auto warp_layout =
      base_layout->Repeat({warp_m / 16, warp_n / 16}, false, false);
  auto block_layout =
      warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, true);
  return block_layout;
}

Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
                               const int warp_m, const int warp_n,
                               const int element_size) {

@@ -730,6 +747,7 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
  if (!k_inner && element_size == 8) // int8 KxN
    return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
  else if (mat_continuous % (vector_size * 8) == 0)
    // return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
    return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
  else if (mat_continuous % (vector_size * 4) == 0)
    return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
src/layout/layout.h
@@ -150,6 +150,9 @@ Fragment makeGemmSparseFragmentC(const int block_m, const int block_n,
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
                               const int warp_m, const int warp_n,
                               const int element_size);
Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
                              const int warp_m, const int warp_n,
                              const int element_size);
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
                                 const int warp_m, const int warp_n,
                                 const int element_size);
src/op/gemm.cc
@@ -4,7 +4,7 @@
 */
#include "gemm.h"
#include <fstream>
#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>

@@ -828,9 +828,16 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
    ICHECK(C.scope() == "local.fragment")
        << "CDNA gemm (FMMA) only supports C in local.fragment scope, got "
        << C.scope();
    if (TargetIsDCU(T.target)) {
      auto fragment =
          makeGemmFragmentCDCU(M, N, M / warp_m, N / warp_n, C->dtype.bits());
      results.Set(C, fragment->BindThreadRange(thread_range));
    } else {
      auto fragment =
          makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
      results.Set(C, fragment->BindThreadRange(thread_range));
    }
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
src/target/codegen_hip.cc
@@ -137,6 +137,7 @@ void CodeGenTileLangHIP::PrintExtraAttrs(const PrimFunc &f, std::ostream &os) {
std::string CodeGenTileLangHIP::Finish() {
  // hip must need a header file.
  decl_stream << "#define HIP_ENABLE_WARP_SYNC_BUILTINS\n";
  decl_stream << "#include <hip/hip_runtime.h>\n";
  if (need_mma_h_) {
    decl_stream << "#include <mma.h>\n";

@@ -146,12 +147,12 @@ std::string CodeGenTileLangHIP::Finish() {
    decl_stream << "#include <tl_templates/hip/hip_fp8.h>\n";
  }
  decl_stream << "#include <tl_templates/dcu_hip/gemm.h>\n";
  decl_stream << "#include <tl_templates/dcu_hip/copy.h>\n";
  decl_stream << "#include <tl_templates/dcu_hip/reduce.h>\n";
  decl_stream << "#include <tl_templates/dcu_hip/ldsm.h>\n";
  decl_stream << "#include <tl_templates/dcu_hip/threadblock_swizzle.h>\n";
  decl_stream << "#include <tl_templates/dcu_hip/debug.h>\n";
  decl_stream << "\n";
  return CodeGenC::Finish();
}

@@ -952,6 +953,71 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
    replacer.register_rule("{c_ref}", c_ref);
    replacer.register_rule("{c_bias}", c_bias);
    os << replacer.rewrite(call_mfma_code);
  } else if (op->op.same_as(tl::tvm_mmac())) {
    // arg 0: prefix: {otype}_{intrM}x{intrN}x{intrK}_{itype}
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
    // arg 3: A precision: float16, float32, ...
    // arg 4: B precision: float16, float32, ...
    // arg 5: C precision: float32, float64, ...
    // arg 6: A multiplicand
    // arg 7: A multiplicand index
    // arg 8: B multiplicand
    // arg 9: B multiplicand index
    // arg 10: C accumulator
    // arg 11: C accumulator index
    ICHECK(op->args.size() == 12U) << "Invalid number of arguments for tvm_mfma";
    std::string prefix = Downcast<StringImm>(op->args[0])->value;
    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
    std::string a_ref = this->PrintExpr(op->args[6]);
    std::string a_bias = this->PrintExpr(op->args[7]);
    std::string b_ref = this->PrintExpr(op->args[8]);
    std::string b_bias = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_bias = this->PrintExpr(op->args[11]);
    ICHECK(A_layout == "row" || B_layout == "row")
        << "Matrix core only support row major";
    // map for dtype -> float32x4 -> float4
    std::unordered_map<std::string, std::string> dtype_map = {
        {"int8", "char"},
        {"int32", "int"},
        {"int8x4", "int32_t"},
        {"int8x8", "int64_t"},
        {"int32x4", "int32x4"},
        {"float16", "half"},
        {"float32", "float"},
        {"float64", "double"},
        {"float16x4", "float16x4"},
        {"bfloat16x4", "bfloat16x4"},
        {"float32x4", "float32x4"},
        {"float8_e4m3fnuzx4", "fp8_e4_4_t"},
        {"float8_e4m3fnuzx8", "long"},
        {"float32x16", "float32x16"}};
    std::string call_mmac_code = R"({
  *((({C_dtype}*){c_ref}) + {c_bias}) = {mmac_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
                                                       *((({B_dtype}*){b_ref}) + {b_bias}),
                                                       *((({C_dtype}*){c_ref}) + {c_bias}));
})";
    std::string mmac_buildin = "__builtin_amdgcn_mmac_" + prefix;
    Replacer replacer;
    replacer.register_rule("{mmac_buildin}", mmac_buildin);
    replacer.register_rule("{A_dtype}", dtype_map[A_dtype]);
    replacer.register_rule("{B_dtype}", dtype_map[B_dtype]);
    replacer.register_rule("{C_dtype}", dtype_map[C_dtype]);
    replacer.register_rule("{a_ref}", a_ref);
    replacer.register_rule("{a_bias}", a_bias);
    replacer.register_rule("{b_ref}", b_ref);
    replacer.register_rule("{b_bias}", b_bias);
    replacer.register_rule("{c_ref}", c_ref);
    replacer.register_rule("{c_bias}", c_bias);
    os << replacer.rewrite(call_mmac_code);
  } else if (op->op.same_as(builtin::thread_return())) {
    os << "return";
  } else if (op->op.same_as(tl::tl_gemm())) {
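To make the substitution concrete, here is a small Python re-creation of the Replacer expansion (illustrative only; the buffer names A_local, B_local, and C_local are placeholders, not from this commit) for a float16 GEMM with prefix f32_16x16x16f16, where dtype_map sends float16x4 and float32x4 to themselves:

# Hedged sketch: mirrors the call_mmac_code template above for one dtype combo.
call_mmac_code = (
    "*(((float32x4*){c_ref}) + {c_bias}) = __builtin_amdgcn_mmac_f32_16x16x16f16("
    "*(((float16x4*){a_ref}) + {a_bias}), "
    "*(((float16x4*){b_ref}) + {b_bias}), "
    "*(((float32x4*){c_ref}) + {c_bias}));"
)
print(call_mmac_code.format(a_ref="A_local", a_bias=0,
                            b_ref="B_local", b_bias=0,
                            c_ref="C_local", c_bias=0))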
src/target/intrin_rule_hip.cc
@@ -240,6 +240,16 @@ TVM_REGISTER_OP("tir.fmod")
    DispatchPureExtern<HIPMath>);

// Register low-level builtin ops.
TVM_REGISTER_OP("tir.hip.__shfl")
    .set_num_inputs(3)
    .add_argument("var", "Expr", "Value to shuffle")
    .add_argument("lane", "Expr", "Source lane")
    .add_argument("width", "Expr", "Warp width")
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl")
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TVM_REGISTER_OP("tir.hip.__shfl_sync")
    .set_num_inputs(4)
    .add_argument("mask", "Expr", "The thread mask.")

@@ -286,4 +296,4 @@ TVM_REGISTER_OP("tir.hip.__activemask")
} // namespace intrin
} // namespace codegen
} // namespace tvm
\ No newline at end of file
src/target/utils.cc
@@ -5,6 +5,7 @@
#include "utils.h"

namespace tvm {
namespace tl {

@@ -78,6 +79,16 @@ bool TargetIsCDNA(Target target) {
  return false;
}

bool TargetIsDCU(Target target) {
  if (!TargetIsRocm(target))
    return false;
  if (target->attrs.count("mcpu")) {
    std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
    // if mcpu starts with "gfx936", it is DCU
    return mcpu.find("gfx936") == 0;
  }
  return false;
}

bool TargetHasAsyncCopy(Target target) {
  if (TargetIsCuda(target)) {
    int arch = GetArchInt(target);
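For reference, a target carrying the mcpu attribute that this predicate looks for can be constructed with TVM's public Target API (a minimal sketch; the exact gfx936 variant string is an assumption):

# Hedged sketch: a ROCm target whose mcpu starts with "gfx936" would be
# classified as DCU by TargetIsDCU above.
import tvm
target = tvm.target.Target("rocm -mcpu=gfx936")
print(target.attrs["mcpu"])  # "gfx936"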
src/target/utils.h
@@ -22,6 +22,7 @@ bool TargetIsHopper(Target target);
bool TargetIsSm100(Target target);
bool TargetIsSM120(Target target);
bool TargetIsCDNA(Target target);
bool TargetIsDCU(Target target);
bool TargetHasAsyncCopy(Target target);
bool TargetHasLdmatrix(Target target);
src/tl_templates/dcu_hip/common.h (new file, mode 100644)
#pragma once
#include "core.hpp"
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
// #include <rocwmma/rocwmma.hpp>
#define HIPRT_INF_F __int_as_float(0x7f800000)
#define HIPRT_NEGINF_F __int_as_float(0xff800000)
#define HIPRT_NAN_F __int_as_float(0x7fffffff)
#define HIPRT_MIN_DENORM_F __int_as_float(0x00000001)
#define HIPRT_MAX_NORMAL_F __int_as_float(0x7f7fffff)
#define HIPRT_NEG_ZERO_F __int_as_float(0x80000000)
#define HIPRT_ZERO_F 0.0f
#define HIPRT_ONE_F 1.0f
/* double precision constants */
#define HIPRT_INF __hiloint2double(0x7ff00000, 0x00000000)
#define HIPRT_NAN __hiloint2double(0xfff80000, 0x00000000)
#define uint unsigned int
#define uchar unsigned char
#define ushort unsigned short
#define TL_DEVICE __forceinline__ __device__
#define TL_DEVICE_NOINLINE __noinline__ __device__
#define TILELANG_CHECK(stmt) \
do { \
hipError_t __err = (stmt); \
if (__err != hipSuccess) { \
snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__, \
__LINE__, hipGetErrorName(__err), hipGetErrorString(__err)); \
return -1; \
} \
} while (0)
#define TILELANG_CHECK_LAST_ERROR(kernel_name) \
do { \
hipError_t __err = hipGetLastError(); \
if (__err != hipSuccess) { \
snprintf(error_buf, ERROR_BUF_SIZE, "kernel_name: %s - %s", \
hipGetErrorName(__err), hipGetErrorString(__err)); \
return -1; \
} \
} while (0)
#define half _Float16
#define __float2half_rn(x) half(x)
#define hpow __ocml_pown_f16
#define hsqrt __ocml_sqrt_f16
using float16_t = _Float16;
using float16x2 = __attribute__((__vector_size__(2 * sizeof(float16_t)))) float16_t;
using float16x4 = __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t;
using float16x8 = __attribute__((__vector_size__(8 * sizeof(float16_t)))) float16_t;
using float16x16 = __attribute__((__vector_size__(16 * sizeof(float16_t)))) float16_t;
using half_t = float16_t;
using bfloat16_t = __hip_bfloat16;

struct bfloat16x2 {
  bfloat16_t x, y;
};
struct bfloat16x4 {
  bfloat16_t data[4];
};
struct bfloat16x8 {
  bfloat16_t data[8];
};
struct bfloat16x16 {
  bfloat16_t data[16];
};

typedef __attribute__((__vector_size__(4 * sizeof(short)))) short bfloat16x4_vec;

using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
using int8x4 = __attribute__((__vector_size__(4 * sizeof(int8_t)))) int8_t;

// Pack two half_t values.
TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}

// Pack two bfloat16_t values.
TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}

template <typename T> struct is_half_type : std::false_type {};
template <> struct is_half_type<__half> : std::true_type {};
template <> struct is_half_type<half_t> : std::true_type {};

template <typename T>
inline constexpr bool is_half_v = is_half_type<std::decay_t<T>>::value;

template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 *address, T2 val) {
  if constexpr (is_half_v<T1>) {
    __half *addr = reinterpret_cast<__half *>(address);
    __half hval = __float2half(static_cast<float>(val));
    atomicAdd(addr, hval);
  } else {
    atomicAdd(address, static_cast<T1>(val));
  }
}

template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 &ref, T2 val) {
  AtomicAdd(&ref, val);
}

template <typename T1, typename T2>
TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val) {
  return atomicAdd(&ref, static_cast<T1>(val));
}

template <typename T>
TL_DEVICE void AtomicAddx4(T *ref, const T val[4]) {
  atomicAdd(&ref[0], val[0]);
  atomicAdd(&ref[1], val[1]);
  atomicAdd(&ref[2], val[2]);
  atomicAdd(&ref[3], val[3]);
}
\ No newline at end of file
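The two packing helpers above just splice the raw 16-bit payloads into one 32-bit word, low half first. A quick host-side mirror in plain Python (illustrative only, using struct's IEEE half-precision format):

import struct

def pack_half2(x: float, y: float) -> int:
    # Mirrors __pack_half2: y's fp16 bits in the high half, x's in the low half.
    (v0,) = struct.unpack("<H", struct.pack("<e", x))
    (v1,) = struct.unpack("<H", struct.pack("<e", y))
    return (v1 << 16) | v0

print(hex(pack_half2(1.0, -2.0)))  # 0xc0003c00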
src/tl_templates/dcu_hip/copy.h
0 → 100644
View file @
2c490782
#pragma once

#include "common.h"

using f32 = float;
// using f16 = _Float16;
using u8 = std::uint8_t;
using u16 = std::uint16_t;
using u32 = std::uint32_t;
using index_t = u32;
using ck_tile::int32x4_t;

struct __attribute__((packed)) buffer_resource {
  const void *ptr;
  uint32_t range;
  uint32_t config;
};

CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr,
                                                   uint32_t size = 0xffffffff) {
  buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
  int32x4_t r = __builtin_bit_cast(int32x4_t, res);
  r.x = __builtin_amdgcn_readfirstlane(r.x);
  r.y = __builtin_amdgcn_readfirstlane(r.y);
  r.z = __builtin_amdgcn_readfirstlane(r.z);
  r.w = __builtin_amdgcn_readfirstlane(r.w);
  return r;
}

__device__ void init_m0(uint32_t m0_value) {
  asm volatile("s_mov_b32 m0, %0" : : "s"(m0_value) : "memory");
}

__device__ void inc_m0(uint32_t m0_inc) {
  asm volatile("s_add_u32 m0, %0, m0" : : "n"(m0_inc) : "memory");
}

namespace tl {

// AMDGPU automatically commit memory fence
TL_DEVICE void cp_async_commit() {}

// Global Memory only fence
__device__ void async_gld_fence(index_t cnt) {
  asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}

// Global Memory and Shared Memory fence
__device__ void async_gld_sld_fence(index_t cnt) {
  asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
}

__device__ void wave_barrier() { asm volatile("s_barrier" : : : "memory"); }

template <int N = 0> TL_DEVICE void cp_async_wait() {
  async_gld_fence(N);
  // or
  // async_gld_sld_fence(N);
}

template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void *smem, int32x4_t rsrc,
                                              index_t voffset) {
  auto const lds_ptr_sgpr =
      __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(smem)));
  asm volatile("s_mov_b32 m0, %0; \n\t"
               "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
               "v"(voffset), "s"(rsrc)
               : "memory");
}

template <int N>
TL_DEVICE void cp_async_gs(void *lds_base_ptr, void *global_base_ptr) {
  if constexpr (N == 16) {
    *(uint4 *)lds_base_ptr = *(uint4 *)global_base_ptr;
  } else if constexpr (N == 8) {
    *(uint2 *)lds_base_ptr = *(uint2 *)global_base_ptr;
  } else if constexpr (N == 4) {
    async_buffer_load_dword_v(
        lds_base_ptr,
        make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
        threadIdx.x * N /*assume 4 bytes*/);
  }
}

template <int N>
TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
                                       void *global_base_ptr, bool cond) {
  if constexpr (N == 16) {
    *(uint4 *)lds_base_ptr =
        cond ? *(uint4 *)global_base_ptr : make_uint4(0, 0, 0, 0);
  } else if constexpr (N == 8) {
    *(uint2 *)lds_base_ptr =
        cond ? *(uint2 *)global_base_ptr : make_uint2(0, 0);
  } else {
    if (cond) {
      async_buffer_load_dword_v(
          lds_base_ptr,
          make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
          threadIdx.x * N /*assume 4 bytes*/);
    } else {
      *(uint4 *)lds_base_ptr = make_uint4(0, 0, 0, 0);
    }
  }
}

} // namespace tl
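make_wave_buffer_resource packs the 64-bit pointer, the byte range, and the per-architecture config dword into the 128-bit buffer descriptor that the buffer_load instructions consume. A host-side mirror of the byte layout in plain Python (illustrative; 0x00020000 is the gfx9 value of CK_TILE_BUFFER_RESOURCE_3RD_DWORD from core.hpp below, and the pointer value is made up):

import struct

def wave_buffer_resource(ptr: int, size: int = 0xFFFFFFFF,
                         config: int = 0x00020000) -> tuple:
    # Same layout as the packed buffer_resource struct: ptr (8 bytes),
    # range (4 bytes), config (4 bytes), viewed as four 32-bit lanes.
    raw = struct.pack("<QII", ptr, size, config)
    return struct.unpack("<4i", raw)

print(wave_buffer_resource(0x7F0000000000))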
src/tl_templates/dcu_hip/core.hpp (new file, mode 100644)
#ifdef __HIPCC__
#define CK_TILE_HOST inline __host__
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_DEVICE_EXTERN __device__
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
#else
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_DEVICE_EXTERN
#define CK_TILE_HOST_DEVICE_EXTERN
#endif
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || \
defined(__gfx9__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#else
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
namespace ck_tile {

using int32x4_t = int32_t __attribute__((ext_vector_type(4)));

template <typename T> CK_TILE_HOST_DEVICE constexpr T max(T x) { return x; }

template <typename T> CK_TILE_HOST constexpr T max(T x, T y) {
  return x > y ? x : y;
}

template <typename T> CK_TILE_DEVICE constexpr T max(T x, T y) {
  return x > y ? x : y;
}

template <> CK_TILE_DEVICE float max(float x, float y) {
  return __builtin_fmaxf(x, y); // can result in v_max3_f32
}

template <> CK_TILE_DEVICE double max(double x, double y) {
  return __builtin_fmax(x, y); // maybe still v_max3_f32
}

template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... ys) {
  static_assert(sizeof...(Ys) > 0, "not enough argument");
  return max(x, max(ys...));
}

template <typename T> CK_TILE_HOST_DEVICE constexpr T min(T x) { return x; }

template <typename T> CK_TILE_HOST constexpr T min(T x, T y) {
  return x < y ? x : y;
}

template <typename T> CK_TILE_DEVICE constexpr T min(T x, T y) {
  return x < y ? x : y;
}

template <> CK_TILE_DEVICE float min(float x, float y) {
  return __builtin_fminf(x, y);
}

template <> CK_TILE_DEVICE double min(double x, double y) {
  return __builtin_fmin(x, y);
}

template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys) {
  static_assert(sizeof...(Ys) > 0, "not enough argument");
  return min(x, min(ys...));
}

} // namespace ck_tile
src/tl_templates/dcu_hip/debug.h (new file, mode 100644)
#pragma once
#include <hip/hip_runtime.h>
// Base template declaration
template <typename T> __device__ void debug_print_var(const char *msg, T var);

// Specialization for signed char type
template <>
__device__ void debug_print_var<signed char>(const char *msg, signed char var) {
  const char *safe_msg = msg;
  int value = static_cast<int>(var);
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=signed "
         "char value=%d\n",
         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value);
}

// Specialization for unsigned char type
template <>
__device__ void debug_print_var<unsigned char>(const char *msg,
                                               unsigned char var) {
  const char *safe_msg = msg;
  unsigned int value = static_cast<unsigned int>(var);
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
         "dtype=unsigned char value=%u\n",
         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value);
}

// Specialization for int type
template <> __device__ void debug_print_var<int>(const char *msg, int var) {
  const char *safe_msg = msg;
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int "
         "value=%d\n",
         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var);
}

// Specialization for unsigned int type
template <>
__device__ void debug_print_var<unsigned int>(const char *msg,
                                              unsigned int var) {
  const char *safe_msg = msg;
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
         "dtype=unsigned int value=%u\n",
         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var);
}

// Specialization for float type
template <>
__device__ void debug_print_var<float>(const char *msg, float var) {
  const char *safe_msg = msg;
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float "
         "value=%f\n",
         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var);
}

// Specialization for double type
template <>
__device__ void debug_print_var<double>(const char *msg, double var) {
  const char *safe_msg = msg;
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=double "
         "value=%lf\n",
         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var);
}

// Specialization for bool type
template <> __device__ void debug_print_var<bool>(const char *msg, bool var) {
  const char *safe_msg = msg;
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=bool "
         "value=%s\n",
         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z,
         var ? "true" : "false");
}

// Specialization for short type
template <>
__device__ void debug_print_var<short>(const char *msg, short var) {
  const char *safe_msg = msg;
  int value = static_cast<int>(var);
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=short "
         "value=%d\n",
         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value);
}

// Specialization for unsigned short type
template <>
__device__ void debug_print_var<unsigned short>(const char *msg,
                                                unsigned short var) {
  const char *safe_msg = msg;
  unsigned int value = static_cast<unsigned int>(var);
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
         "dtype=unsigned short value=%u\n",
         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value);
}

// Template declaration for device-side debug printing (buffer only)
template <typename T>
__device__ void debug_print_buffer_value(const char *msg, const char *buf_name,
                                         int index, T var);

// Specialization for signed char type
template <>
__device__ void debug_print_buffer_value<signed char>(const char *msg,
                                                      const char *buf_name,
                                                      int index,
                                                      signed char var) {
  const char *safe_msg = msg;
  const char *safe_buf_name = buf_name;
  int value = static_cast<int>(var);
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=signed char value=%d\n",
         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
         index, value);
}

// Specialization for unsigned char type
template <>
__device__ void debug_print_buffer_value<unsigned char>(const char *msg,
                                                        const char *buf_name,
                                                        int index,
                                                        unsigned char var) {
  const char *safe_msg = msg;
  const char *safe_buf_name = buf_name;
  unsigned int value = static_cast<unsigned int>(var);
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=unsigned char value=%u\n",
         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
         index, value);
}

// Specialization for integer type
template <>
__device__ void debug_print_buffer_value<int>(const char *msg,
                                              const char *buf_name, int index,
                                              int var) {
  const char *safe_msg = msg;
  const char *safe_buf_name = buf_name;
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=int value=%d\n",
         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
         index, var);
}

// Specialization for float type
template <>
__device__ void debug_print_buffer_value<float>(const char *msg,
                                                const char *buf_name, int index,
                                                float var) {
  const char *safe_msg = msg;
  const char *safe_buf_name = buf_name;
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=float value=%f\n",
         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
         index, var);
}

// Specialization for half_t type
template <>
__device__ void debug_print_buffer_value<half_t>(const char *msg,
                                                 const char *buf_name,
                                                 int index, half_t var) {
  const char *safe_msg = msg;
  const char *safe_buf_name = buf_name;
  float value = static_cast<float>(var);
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=half_t value=%f\n",
         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
         index, value);
}

// Specialization for double type
template <>
__device__ void debug_print_buffer_value<double>(const char *msg,
                                                 const char *buf_name,
                                                 int index, double var) {
  const char *safe_msg = msg;
  const char *safe_buf_name = buf_name;
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=double value=%lf\n",
         safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
         (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
         index, var);
}
src/tl_templates/dcu_hip/gemm.h (new file, mode 100644)
#pragma once
#include "common.h"
#include <type_traits>
namespace tl {

// Trait to determine the MFMA instruction to use based on data type
template <typename T> struct MfmaTraits;

// Specialization for int8
template <> struct MfmaTraits<int8_t> {
  template <typename AccType>
  static TL_DEVICE void mfma_op(const int8_t *b, const int8_t *a, AccType *c) {
    int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b));
    int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a));
    *c = __builtin_amdgcn_mmac_i32_16x16x32i8(*b_packed, *a_packed, *c);
  }
};

// Specialization for half/float16
template <> struct MfmaTraits<half> {
  template <typename AccType>
  static TL_DEVICE void mfma_op(const half *b, const half *a, AccType *c) {
    *c = __builtin_amdgcn_mmac_f32_16x16x16f16(*((float16x4 *)b),
                                               *((float16x4 *)a), *c);
  }
};

// Specialization for bfloat16_t
template <> struct MfmaTraits<bfloat16_t> {
  template <typename AccType>
  static TL_DEVICE void mfma_op(const bfloat16_t *b, const bfloat16_t *a,
                                AccType *c) {
    bfloat16x4_vec b_vec, a_vec;
    // Reinterpret the pointers
    short *b_short = reinterpret_cast<short *>(const_cast<bfloat16_t *>(b));
    short *a_short = reinterpret_cast<short *>(const_cast<bfloat16_t *>(a));
    // Copy the data
    for (int i = 0; i < 4; ++i) {
      b_vec[i] = b_short[i];
      a_vec[i] = a_short[i];
    }
    // Call the intrinsic and store the result directly to c
    *c = __builtin_amdgcn_mmac_f32_16x16x16bf16(b_vec, a_vec, *c);
  }
};

#if defined(HIP_FP8_ENABLED)
// Specialization for fp8_e4_t
template <> struct MfmaTraits<fp8_e4_t> {
  template <typename AccType>
  static TL_DEVICE void mfma_op(const fp8_e4_t *b, const fp8_e4_t *a,
                                AccType *c) {
    int64_t a_val = *reinterpret_cast<const int64_t *>(a);
    int64_t b_val = *reinterpret_cast<const int64_t *>(b);
    *c = __builtin_amdgcn_mmac_f32_16x16x32_fp8_fp8(b_val, a_val, *c);
  }
};
#endif

// ref to bitblas/tl/mfma_macro_generator.py::kPack
template <int M, int N, int K, int num_warp_n, int num_warp_m, bool TransposeA,
          bool TransposeB, bool clear_accum, int kPack, typename A_type,
          typename B_type, typename C_type, typename AccDataType = float>
class GemmTensorOp {
public:
  // static_assert(!clear_accum, "clear_accum=true is not supported yet");

  static constexpr int micro_size_x = 16;
  static constexpr int micro_size_y = 16;
  static constexpr int micro_size_k = 32 / sizeof(A_type);
  static constexpr int vec_size = 8 / sizeof(A_type);

  // This part comes from the Codegen
  static constexpr int M_Tile = N;
  static constexpr int N_Tile = M;
  static constexpr int K_Tile = K;

  static constexpr int block_row_warps = num_warp_m;
  static constexpr int block_col_warps = num_warp_n;

  static constexpr int inner_k = K_Tile / (micro_size_k * kPack);
  static constexpr int warp_rows = M_Tile / (block_row_warps * micro_size_x);
  static constexpr int warp_cols = N_Tile / (block_col_warps * micro_size_y);

  // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen
  // part.
  static constexpr bool kPadA = true;
  static constexpr bool kPadB = true;
  static constexpr bool kPadC = true;

  static constexpr int BANK_SIZE_BYTES = 128;
  static constexpr int warp_size = 64;

  TL_DEVICE static constexpr auto reverse_index_map(int thread_id,
                                                    int local_id) {
    return std::make_pair(thread_id % 16,
                          (thread_id / 16) * (vec_size * kPack) + local_id);
  }

  TL_DEVICE static constexpr auto reverse_index_map_transposed(int thread_id,
                                                               int local_id) {
    return std::make_pair((thread_id / 16) * (vec_size * kPack) + local_id,
                          thread_id % 16);
  }

  /*
   * Detailed Implementation please
   * checkout bitblas/tl/utils.py:get_swizzle_layout
   */
  template <int continuous = 32, int element_size = 2>
  TL_DEVICE static auto make_mfma_swizzle_layout(const int row, const int col) {
    const auto dtype_bits = element_size * 8;
    const int numBanks = 32;
    const int bankBitWidth = 32;
    const int SIMDWidth = 16;
    const int vecSize = vec_size * kPack;
    const int innerDimLength = continuous;
    const int typeWidthInBit = dtype_bits;

    const int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
    const int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
    const int maxPhase =
        std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
    const int phase = (row / perPhase) % maxPhase;
    const int colOffSwizzled = (((col / vecSize) ^ phase) * vecSize);
    const int colOffOrdered = col % vecSize;
    const int colOff = colOffSwizzled + colOffOrdered;

    return std::make_pair(row, colOff);
  }

  template <int continuous = 32, int element_size = 2>
  TL_DEVICE static constexpr auto make_layout_padded(const int row,
                                                     const int col) {
    return std::make_pair(row, col);
  }

  template <int continuous = 32, int element_size = 2>
  TL_DEVICE static constexpr auto make_swizzle_layout(const int row,
                                                      const int col) {
    auto [n_row, n_col] =
        make_mfma_swizzle_layout<continuous, element_size>(row, col);
    return n_row * continuous + n_col;
  }

  static TL_DEVICE void body(A_type *A_shared, B_type *B_shared,
                             C_type *C_local) {
    auto tid = threadIdx.x;
    auto warp_id = tid / warp_size;
    auto warp_n = warp_id / block_row_warps;
    auto warp_m = warp_id % block_row_warps;
    auto warp_row_tiles = warp_rows * micro_size_x;
    auto warp_col_tiles = warp_cols * micro_size_y;
    auto lane_id = tid % warp_size;
    auto tx = lane_id;
    auto alane_id = lane_id;
    auto blane_id =
        (lane_id & 15) / 4 + (lane_id & 3) * 4 + (lane_id / 16) * 16;

    constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
    constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
    constexpr auto local_size_c = (micro_size_x * micro_size_y) / warp_size;
    constexpr auto last_dim_b = TransposeB ? K_Tile : M_Tile;
    constexpr auto last_dim_a = TransposeA ? N_Tile : K_Tile;

    B_type B_local[warp_rows * kPack * local_size_b];
    A_type A_local[warp_cols * kPack * local_size_a];

    for (int ki = 0; ki < inner_k; ki++) {
      // Fetch B into register
      for (int i = 0; i < warp_rows; i++) {
        const auto l = warp_m * warp_row_tiles + i * micro_size_x;
        const auto r = ki * (kPack * micro_size_k);
        for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) {
          if constexpr (TransposeB) {
            auto [row, col] = reverse_index_map(blane_id, local_id);
            B_local[i * kPack * local_size_b + local_id] =
                B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
                    l + row, r + col)];
          } else {
            auto [row, col] = reverse_index_map_transposed(blane_id, local_id);
            B_local[i * kPack * local_size_b + local_id] =
                B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
                    r + row, l + col)];
          }
        }
      }

      // Fetch A into register
      for (int j = 0; j < warp_cols; j++) {
        const auto l = warp_n * warp_col_tiles + j * micro_size_y;
        const auto r = ki * (kPack * micro_size_k);
        for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) {
          if constexpr (TransposeA) {
            auto [row, col] = reverse_index_map_transposed(alane_id, local_id);
            A_local[j * kPack * local_size_a + local_id] =
                A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
                    r + row, l + col)];
          } else {
            auto [row, col] = reverse_index_map(alane_id, local_id);
            A_local[j * kPack * local_size_a + local_id] =
                A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
                    l + row, r + col)];
          }
        }
      }

      // Compute
      for (int kp = 0; kp < kPack; kp++) {
        for (int i = 0; i < warp_rows; ++i) {
          for (int j = 0; j < warp_cols; ++j) {
            auto acc_ptr = ((float32x4 *)C_local) + ((i * warp_cols) + j);
            auto a_ptr = ((A_type *)A_local) + (j * kPack + kp) * vec_size;
            auto b_ptr = ((B_type *)B_local) + (i * kPack + kp) * vec_size;
            // Use the trait to select the correct MFMA instruction, either
            // fp8, fp16 or bf16 currently
            MfmaTraits<A_type>::mfma_op(a_ptr, b_ptr, acc_ptr);
          }
        }
      }
    }
  }

  static TL_DEVICE void body_rs(A_type *A_local, B_type *B_shared,
                                C_type *C_local) {
    auto tid = threadIdx.x;
    auto warp_id = tid / warp_size;
    auto warp_n = warp_id / block_row_warps;
    auto warp_m = warp_id % block_row_warps;
    auto warp_row_tiles = warp_rows * micro_size_x;
    auto warp_col_tiles = warp_cols * micro_size_y;
    auto lane_id = tid % warp_size;
    auto tx = lane_id;
    auto alane_id = lane_id;
    auto blane_id =
        (lane_id & 15) / 4 + (lane_id & 3) * 4 + (lane_id / 16) * 16;

    constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
    constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
    constexpr auto local_size_c = (micro_size_x * micro_size_y) / warp_size;
    constexpr auto last_dim_b = TransposeB ? K_Tile : M_Tile;
    constexpr auto last_dim_a = TransposeA ? N_Tile : K_Tile;

    B_type B_local[warp_rows * kPack * local_size_b];

    for (int ki = 0; ki < inner_k; ki++) {
      // Fetch B into register
      for (int i = 0; i < warp_rows; i++) {
        const auto l = warp_m * warp_row_tiles + i * micro_size_x;
        const auto r = ki * (kPack * micro_size_k);
        for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) {
          if constexpr (TransposeB) {
            auto [row, col] = reverse_index_map(blane_id, local_id);
            B_local[i * kPack * local_size_b + local_id] =
                B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
                    l + row, r + col)];
          } else {
            auto [row, col] = reverse_index_map_transposed(blane_id, local_id);
            B_local[i * kPack * local_size_b + local_id] =
                B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
                    r + row, l + col)];
          }
        }
      }

      // Compute
      for (int kp = 0; kp < kPack; kp++) {
        for (int i = 0; i < warp_rows; ++i) {
          for (int j = 0; j < warp_cols; ++j) {
            auto acc_ptr = ((float32x4 *)C_local) + ((i * warp_cols) + j);
            auto b_ptr = ((B_type *)B_local) + (i * kPack + kp) * vec_size;
            auto a_ptr = ((A_type *)A_local) +
                         (ki * warp_cols * kPack + j * kPack + kp) * vec_size;
            // Use the trait to select the correct MFMA instruction, either
            // fp8, fp16 or bf16 currently
            MfmaTraits<A_type>::mfma_op(a_ptr, b_ptr, acc_ptr);
          }
        }
      }
    }
  }
};

} // namespace tl

namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, int kPack, typename A_type,
          typename B_type, typename C_type>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using Compute = GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                               trans_B, clear_accum, kPack, A_type, B_type,
                               C_type>;
  Compute::body(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, int kPack, typename A_type,
          typename B_type, typename C_type>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using Compute = GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                               trans_B, clear_accum, kPack, A_type, B_type,
                               C_type>;
  Compute::body_rs(pA, pB, accum);
}

} // namespace tl
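As a worked check of the constexpr math in GemmTensorOp (plain Python; the tile shape, kPack, and wave grid are assumed values for illustration), take half inputs (2 bytes), a 128x128x32 tile, kPack = 1, and a 2x2 wave grid:

# Hedged mirror of GemmTensorOp's derived constants for A_type = half.
sizeof_A = 2
M_Tile = N_Tile = 128
K_Tile = 32
kPack, num_warp_m, num_warp_n = 1, 2, 2

micro_size_k = 32 // sizeof_A                 # 16
vec_size = 8 // sizeof_A                      # 4
inner_k = K_Tile // (micro_size_k * kPack)    # 2
warp_rows = M_Tile // (num_warp_m * 16)       # 4
warp_cols = N_Tile // (num_warp_n * 16)       # 4
print(micro_size_k, vec_size, inner_k, warp_rows, warp_cols)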
src/tl_templates/dcu_hip/hip_fp8.h (new file, mode 100644)
#include <hip/amd_detail/amd_hip_fp8.h>
#define HIP_FP8_ENABLED 1
using fp8_e4_t = __hip_fp8_e4m3_fnuz;
using fp8_e4_2_t = __hip_fp8x2_e4m3_fnuz;

// Simple wrapper that provides member access for generated code
struct fp8_e4_4_t {
  union {
    __hip_fp8x4_e4m3_fnuz data;
    struct {
      fp8_e4_t x, y, z, w;
    };
  };

  // Default constructor
  __device__ fp8_e4_4_t() = default;

  // Constructor from __hip_fp8x4_e4m3_fnuz
  __device__ fp8_e4_4_t(const __hip_fp8x4_e4m3_fnuz &val) : data(val) {}

  // Constructor from float4
  __device__ fp8_e4_4_t(const float4 &val) : data(val) {}

  // Conversion operator to __hip_fp8x4_e4m3_fnuz
  __device__ operator __hip_fp8x4_e4m3_fnuz() const { return data; }

  // Assignment operator
  __device__ fp8_e4_4_t &operator=(const __hip_fp8x4_e4m3_fnuz &val) {
    data = val;
    return *this;
  }
};

struct __align__(8) fp8_e4_8_t {
  fp8_e4_4_t x;
  fp8_e4_4_t y;
};

struct __align__(16) fp8_e4_16_t {
  fp8_e4_8_t x;
  fp8_e4_8_t y;
};

__device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
                                      fp8_e4_t w) {
  // reinterpret the 4 fp8_e4_t values to signed char value and shift
  signed char x_char = *reinterpret_cast<signed char *>(&x);
  signed char y_char = *reinterpret_cast<signed char *>(&y);
  signed char z_char = *reinterpret_cast<signed char *>(&z);
  signed char w_char = *reinterpret_cast<signed char *>(&w);
  int res = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char;
  return *reinterpret_cast<fp8_e4_4_t *>(&res);
}

__device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
                                      fp8_e4_t w, fp8_e4_t v, fp8_e4_t u,
                                      fp8_e4_t t, fp8_e4_t s) {
  signed char x_char = *reinterpret_cast<signed char *>(&x);
  signed char y_char = *reinterpret_cast<signed char *>(&y);
  signed char z_char = *reinterpret_cast<signed char *>(&z);
  signed char w_char = *reinterpret_cast<signed char *>(&w);
  signed char v_char = *reinterpret_cast<signed char *>(&v);
  signed char u_char = *reinterpret_cast<signed char *>(&u);
  signed char t_char = *reinterpret_cast<signed char *>(&t);
  signed char s_char = *reinterpret_cast<signed char *>(&s);
  int a = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char;
  int b = (s_char << 24) | (t_char << 16) | (u_char << 8) | v_char;
  fp8_e4_8_t res;
  res.x = *reinterpret_cast<fp8_e4_4_t *>(&a);
  res.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
  return res;
}
src/tl_templates/dcu_hip/ldsm.h (new file, mode 100644)
#pragma once
#include "common.h"
src/tl_templates/dcu_hip/reduce.h (new file, mode 100644)
#pragma once
#include "common.h"
namespace tl {

struct SumOp {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    return x + y;
  }
};

struct MaxOp {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    return ck_tile::max(x, y);
  }
};

struct MinOp {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    return ck_tile::min(x, y);
  }
};

// Detect half types
template <typename T> struct is_half_type : std::false_type {};
template <> struct is_half_type<__half> : std::true_type {};
template <> struct is_half_type<_Float16> : std::true_type {};

template <typename T>
inline constexpr bool is_half_v = is_half_type<std::decay_t<T>>::value;

template <class Reducer, int threads, int scale, int thread_offset = 0>
struct AllReduce {
  static_assert(threads == 1024 || threads == 512 || threads == 256 ||
                threads == 128 || threads == 64 || threads == 32 ||
                threads == 16 || threads == 8 || threads == 4 || threads == 2);
  static_assert(threads % scale == 0);

  template <typename T> static __device__ T run(T x, T *red_buf = nullptr) {
    constexpr int offset = threads / 2;
    constexpr int warpSize = 64;
    if constexpr (offset >= warpSize) {
      __syncthreads();
      red_buf[threadIdx.x] = x;
      __syncthreads();
      x = Reducer()(x, red_buf[threadIdx.x ^ offset]);
    } else {
      if constexpr (is_half_v<T>) {
        unsigned short x_raw;
        if constexpr (std::is_same_v<std::decay_t<T>, __half>) {
          x_raw = __half_as_ushort(x);
        } else {
          // _Float16
          union {
            _Float16 f;
            unsigned short s;
          } u;
          u.f = x;
          x_raw = u.s;
        }
        unsigned short shuffled_raw = __shfl_xor(x_raw, offset);
        T shuffled_x;
        if constexpr (std::is_same_v<std::decay_t<T>, __half>) {
          shuffled_x = __ushort_as_half(shuffled_raw);
        } else {
          // _Float16
          union {
            unsigned short s;
            _Float16 f;
          } u;
          u.s = shuffled_raw;
          shuffled_x = u.f;
        }
        x = Reducer()(x, shuffled_x);
      } else {
        x = Reducer()(x, __shfl_xor(x, offset));
      }
    }
    if constexpr (offset == scale) {
      return x;
    } else {
      return AllReduce<Reducer, offset, scale, thread_offset>::run(x, red_buf);
    }
  }
};
template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
  static_assert(threads == 1024 || threads == 512 || threads == 256 ||
                threads == 128 || threads == 64 || threads == 32);

  template <typename T, int SEG = 32>
  static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
                            int H, int W) {
    constexpr int TILE_H = threads / SEG;
    constexpr uint64_t MASK = 0xffffffffffffffffULL;
    const int num_blocks = (H + TILE_H - 1) / TILE_H;
    const int tid = threadIdx.x;
    const int lane = tid % 64;
    const int row = tid / 64;
    for (int b = 0; b < num_blocks; ++b) {
      const int gRow = b * TILE_H + row;
      if (gRow >= H)
        return;
      T carry = (T)0;
      if (reverse) {
        // Start from the last segment for reverse (suffix) mode.
        for (int seg = (W + SEG - 1) / SEG - 1; seg >= 0; --seg) {
          const int col = seg * SEG + lane;
          const int real_row = Axis == 1 ? gRow : col;
          const int real_col = Axis == 1 ? col : gRow;
          T val = (col < W) ? src[real_row * W + real_col] : (T)0;
#pragma unroll
          for (int off = 1; off < SEG; off <<= 1) {
            T n = (T)__shfl_down_sync(MASK, val, off);
            if (lane < SEG - off)
              val += n;
          }
          val += carry;
          if (real_col < W)
            dst[real_row * W + real_col] = val;
          // Lane 0 holds the segment total after the suffix scan.
          T segSum = (T)__shfl_sync(MASK, val, 0);
          if (lane == 0)
            carry = segSum;
          carry = (T)__shfl_sync(MASK, carry, 0);
        }
      } else {
        for (int seg = 0; seg * SEG < W; ++seg) {
          const int col = seg * SEG + lane;
          const int real_row = Axis == 1 ? gRow : col;
          const int real_col = Axis == 1 ? col : gRow;
          T val = (col < W) ? src[real_row * W + real_col] : (T)0;
#pragma unroll
          for (int off = 1; off < SEG; off <<= 1) {
            T n = (T)__shfl_up_sync(MASK, val, off);
            if (lane >= off)
              val += n;
          }
          val += carry;
          if (real_col < W)
            dst[real_row * W + real_col] = val;
          // The last lane of the segment holds the segment total.
          T segSum = (T)__shfl_sync(MASK, val, SEG - 1);
          if (lane == SEG - 1)
            carry = segSum;
          carry = (T)__shfl_sync(MASK, carry, SEG - 1);
        }
      }
    }
  }
};

} // namespace tl
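As a hedged usage sketch (the launch shape and buffer name are assumptions): with 256 threads the first reduction steps cross wavefront boundaries, so AllReduce needs a scratch buffer of one element per thread, and scale = 1 leaves every thread holding the block-wide result.

__global__ void block_sum_demo(const float *in, float *out) {
  __shared__ float red_buf[256]; // one slot per thread
  float x = in[blockIdx.x * 256 + threadIdx.x];
  float total = tl::AllReduce<tl::SumOp, 256, 1>::run(x, red_buf);
  if (threadIdx.x == 0)
    out[blockIdx.x] = total;
}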
src/tl_templates/dcu_hip/threadblock_swizzle.h
0 → 100644
#pragma once
#include "common.h"
namespace tl {

// Remap the linear block index into a panelized raster order: blocks walk
// panel_width rows of tiles at a time, snaking left-to-right and then
// right-to-left on alternating panels, which improves L2 reuse between
// neighboring tiles.
template <int panel_width> TL_DEVICE dim3 rasterization2DRow() {
  auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
  const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
  const unsigned int grid_size = gridDim.x * gridDim.y;
  const unsigned int panel_size = panel_width * gridDim.x;
  const unsigned int panel_offset = block_idx % panel_size;
  const unsigned int panel_idx = block_idx / panel_size;
  const unsigned int total_panel = ceil_div(grid_size, panel_size);
  // The last panel may be narrower than panel_width.
  const unsigned int stride =
      panel_idx + 1 < total_panel
          ? panel_width
          : (grid_size - panel_idx * panel_size) / gridDim.x;
  const unsigned int col_idx = (panel_idx & 1)
                                   ? gridDim.x - 1 - panel_offset / stride
                                   : panel_offset / stride;
  const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width;
  return {col_idx, row_idx, blockIdx.z};
}

// Column-major variant of the panelized raster order above.
template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
  auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
  const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
  const unsigned int grid_size = gridDim.x * gridDim.y;
  const unsigned int panel_size = panel_width * gridDim.y;
  const unsigned int panel_offset = block_idx % panel_size;
  const unsigned int panel_idx = block_idx / panel_size;
  const unsigned int total_panel = ceil_div(grid_size, panel_size);
  const unsigned int stride =
      panel_idx + 1 < total_panel
          ? panel_width
          : (grid_size - panel_idx * panel_size) / gridDim.y;
  const unsigned int row_idx = (panel_idx & 1)
                                   ? gridDim.y - 1 - panel_offset / stride
                                   : panel_offset / stride;
  const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width;
  return {col_idx, row_idx, blockIdx.z};
}

} // namespace tl
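For context, a hedged sketch of the intended call pattern (the kernel is illustrative, not part of this commit): a kernel computes the swizzled coordinates once on entry and uses them wherever it would otherwise read blockIdx.

template <int panel_width>
__global__ void swizzled_kernel(/* ... */) {
  const dim3 bid = tl::rasterization2DRow<panel_width>();
  // Use bid.x / bid.y instead of blockIdx.x / blockIdx.y when
  // computing this block's output tile coordinates.
}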
tilelang/contrib/hipcc.py
...
@@ -61,7 +61,8 @@ def compile_hip(code,
     file_target = path_target if path_target else temp_target

     cmd = ["hipcc"]
-    cmd += ["-O3", '-c']
+    cmd += ["-O1", '-c']
+    cmd += ["-Wno-invalid-constexpr"]
     if isinstance(arch, str):
         cmd += [f"--offload-arch={arch}"]
     if target_format == "hsaco":
...
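hipcc forwards these options to its underlying clang driver: -Wno-invalid-constexpr suppresses clang's invalid-constexpr diagnostic, and the DCU path lowers the default optimization level from -O3 to -O1.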
tilelang/contrib/rocm.py
...
@@ -227,7 +227,7 @@ def have_matrixcore(compute_version=None):
 @tvm.ffi.register_func("tvm_callback_rocm_get_arch", override=True)
-def get_rocm_arch(rocm_path="/opt/rocm"):
+def get_rocm_arch(rocm_path="/opt/dtk"):
     """Utility function to get the AMD GPU architecture

     Parameters
...
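The new default points at Hygon's DTK (DCU ToolKit) install prefix, which ships its ROCm-compatible stack under /opt/dtk rather than /opt/rocm.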