gaoqiong / composable_kernel_ROCM · Commits · 15baccf2

Commit 15baccf2, authored Jul 08, 2024 by Jun Liu
Merge branch 'develop' into amd-develop
Parents: 5029a5a4, a328df25
Changes: 82 files. Showing 20 changed files with 4177 additions and 212 deletions (+4177, −212).
Changed files shown:

- .pre-commit-config.yaml (+0, −0)
- example/01_gemm/CMakeLists.txt (+2, −0)
- example/01_gemm/README.md (+18, −0)
- example/01_gemm/common.hpp (+67, −2)
- example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp (+48, −0)
- example/01_gemm/run_gemm_example_streamk_v2.inc (+298, −0)
- example/CMakeLists.txt (+2, −2)
- example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py (+3, −1)
- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py (+3, −0)
- example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py (+3, −0)
- include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp (+44, −0)
- include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp (+556, −0)
- include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp (+322, −0)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp (+2010, −0)
- include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp (+409, −0)
- include/ck_tile/core/arch/amd_buffer_addressing.hpp (+333, −185)
- include/ck_tile/core/arch/arch.hpp (+3, −5)
- include/ck_tile/core/config.hpp (+9, −0)
- include/ck_tile/core/tensor/buffer_view.hpp (+34, −11)
- include/ck_tile/core/tensor/load_tile.hpp (+13, −6)
.pre-commit-config.yaml (100644 → 100755)

File mode changed from 100644 to 100755.
example/01_gemm/CMakeLists.txt

@@ -22,6 +22,8 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
 add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2)
+add_example_executable(example_gemm_xdl_fp16_streamk_v3 gemm_xdl_fp16_streamk_v3.cpp)
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_streamk_v3)
 add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3)
 add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
...
example/01_gemm/README.md

@@ -7,3 +7,21 @@
 #arg3: run kernel # of times (>1)
 ./bin/example_gemm_xdl 0 1 5
 ```
+
+# Instructions for ```example_gemm_xdl_fp16_streamk_v3```
+
+## Run ```example_gemm_xdl_fp16_streamk_v3```
+```bash
+arg1: verification (0=no, 1=yes)
+arg2: initialization (0=no init, 1=integer value, 2=decimal value)
+arg3: time kernel (0=no, 1=yes)
+arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
+arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)
+arg11: Grid_size(-1 for max occupancy)
+bin/example_gemm_xdl_fp16_streamk_v3 1 2 1 3840 4096 4096 4096 4096 4096 1 -1
+a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
+b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
+c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
+problem {M:3840, N:4096, K:4096, SA:4096, SB:4096, SC:4096, MP:4032, NP:4096, KRead:4096, KP:4096, AK0:512, BK0:2048, MBlock: 18, NBlock: 16, Stream-K Selection:1, Grid size:-1}
+Perf: 0.292022 ms, 441.23 TFlops, 330.348 GB/s, DeviceGemmXdlUniversal<MNPadding, RRR> BlkSize: 256, BlkTile: 224x256x64, WaveTile: 16x16, WaveMap: 7x8, VmemReadVec: 8x8, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, BlkGemmPipelinePrefetchStages: 2
+```
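The TFlops and GB/s figures in that sample output follow from the timing arithmetic added in run_gemm_example_streamk_v2.inc below: flop = 2·M·N·K and bytes = sizeof(A) + sizeof(B) + sizeof(C), both divided by the averaged kernel time. A standalone check of the reported numbers (this snippet is illustration only, not part of the diff; it just plugs the problem size and the ~0.292022 ms time from the log above into those formulas):

#include <cstdio>

int main()
{
    const double M = 3840, N = 4096, K = 4096, ave_time_ms = 0.292022;
    const double flop      = 2.0 * M * N * K;                         // GEMM flop count
    const double num_btype = 2.0 * M * K + 2.0 * K * N + 2.0 * M * N; // fp16 = 2 bytes/element

    std::printf("%.2f TFlops\n", flop / 1.e9 / ave_time_ms);      // ~441.23
    std::printf("%.3f GB/s\n", num_btype / 1.e6 / ave_time_ms);   // ~330.348
    return 0;
}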
example/01_gemm/common.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -45,6 +45,19 @@ struct ProblemSizeStreamK final
     ck::index_t NumSKBlocks = -1;
 };

+struct ProblemSizeStreamK_universal final
+{
+    ck::index_t M = 3840;
+    ck::index_t N = 4096;
+    ck::index_t K = 4096;
+
+    ck::index_t StrideA = 4096;
+    ck::index_t StrideB = 4096;
+    ck::index_t StrideC = 4096;
+
+    ck::index_t Grid_size   = -1; // defaults to max occupancy
+    ck::index_t Streamk_sel = 1;  // defaults to 1-tile SK
+};
+
 struct ProblemSizeSplitK final
 {
...
@@ -123,6 +136,57 @@ bool parse_cmd_args<ProblemSize>(int argc,
     return true;
 }

+template <>
+bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
+                                                  char* argv[],
+                                                  ProblemSizeStreamK_universal& problem_size,
+                                                  ExecutionConfig& config)
+{
+    if(argc == 1)
+    {
+        // use default case
+    }
+    else if(argc == 4)
+    {
+        config.do_verification = std::stoi(argv[1]);
+        config.init_method     = std::stoi(argv[2]);
+        config.time_kernel     = std::stoi(argv[3]);
+    }
+    else if(argc >= 10)
+    {
+        config.do_verification = std::stoi(argv[1]);
+        config.init_method     = std::stoi(argv[2]);
+        config.time_kernel     = std::stoi(argv[3]);
+
+        problem_size.M = std::stoi(argv[4]);
+        problem_size.N = std::stoi(argv[5]);
+        problem_size.K = std::stoi(argv[6]);
+
+        problem_size.StrideA = std::stoi(argv[7]);
+        problem_size.StrideB = std::stoi(argv[8]);
+        problem_size.StrideC = std::stoi(argv[9]);
+
+        if(argc >= 11)
+        {
+            problem_size.Streamk_sel = std::stoi(argv[10]);
+            problem_size.Grid_size   = std::stoi(argv[11]);
+        }
+    }
+    else
+    {
+        std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
+                  << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
+                  << "arg3: time kernel (0=no, 1=yes)" << std::endl
+                  << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
+                  << "arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)"
+                  << "\narg11: Grid_size(-1 for max occupancy)" << std::endl;
+        return false;
+    }
+
+    return true;
+}
+
 template <>
 bool parse_cmd_args<ProblemSizeStreamK>(int argc,
                                         char* argv[],
...
@@ -165,7 +229,8 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc,
                   << std::endl
                   << "arg3: time kernel (0=no, 1=yes)" << std::endl
                   << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
-                  << "arg10: NumSKBlocks(optional)" << std::endl;
+                  << "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)"
+                  << "\narg11: Grid_size(-1 for max occupancy)" << std::endl;
         return false;
     }
...
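The new ProblemSizeStreamK_universal struct can also be filled programmatically instead of through parse_cmd_args. A minimal sketch, assuming it sits in a translation unit like gemm_xdl_fp16_streamk_v3.cpp that already includes run_gemm_example_streamk_v2.inc (the run_gemm driver added later in this commit); the wrapper name run_with_defaults is hypothetical, the struct and ExecutionConfig fields come from the code above:

// Hypothetical driver: bypasses argv parsing and sets the stream-k knobs directly.
int run_with_defaults()
{
    ProblemSizeStreamK_universal problem_size; // defaults: M=3840, N=4096, K=4096, strides 4096
    problem_size.Streamk_sel = 2;              // 2-tile SK instead of the 1-tile default
    problem_size.Grid_size   = -1;             // let the kernel pick num_cu * occupancy

    ExecutionConfig config;
    config.do_verification = 1; // compare against the host reference GEMM
    config.init_method     = 2; // decimal-valued initialization
    config.time_kernel     = 1; // report ms / TFlops / GB/s

    return run_gemm(problem_size, config) ? 0 : 1;
}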
example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp"

using ADataType        = ck::half_t;
using BDataType        = ck::half_t;
using AccDataType      = float;
using CShuffleDataType = ck::half_t;
using CDataType        = ck::half_t;

using ALayout = Row;
using BLayout = Row;
using CLayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;

// clang-format off
using DeviceGemmV2_Streamk_Instance =
    ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3<
        ALayout, BLayout, CLayout,
        ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
        PassThrough, PassThrough, PassThrough, GemmDefault,
        256,
        224, 256, 64, 8, 2,
        16, 16,
        7, 8,
        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
        S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 2, 0,
        1, 2, S<1, 32, 1, 8>, 8,
        ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<
    ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

#include "run_gemm_example_streamk_v2.inc"

int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); }
example/01_gemm/run_gemm_example_streamk_v2.inc (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
    if constexpr(std::is_same_v<DataType, float>) { return 1e-3; }
    else if constexpr(std::is_same_v<DataType, double>) { return 1e-6; }
    else if constexpr(std::is_same_v<DataType, ck::half_t>) { return 1e-3; }
    else if constexpr(std::is_same_v<DataType, ck::bhalf_t>) { return 5e-2; }
    else if constexpr(std::is_same_v<DataType, int32_t>) { return 1e-1; }
    else if constexpr(std::is_same_v<DataType, int8_t>) { return 1e-1; }
    else if constexpr(std::is_same_v<DataType, ck::f8_t>)
    {
        return 1e-1; // 240 and 224 are acceptable
    }
    else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
    {
        return 1.5e-1; // 57344 and 49152 are acceptable
    }
    else { return 1e-3; }
}

template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
    if constexpr(std::is_same_v<DataType, float>) { return 1e-3; }
    else if constexpr(std::is_same_v<DataType, double>) { return 1e-6; }
    else if constexpr(std::is_same_v<DataType, ck::half_t>) { return 1e-3; }
    else if constexpr(std::is_same_v<DataType, ck::bhalf_t>) { return 5e-2; }
    else if constexpr(std::is_same_v<DataType, int32_t>) { return 1e-1; }
    else if constexpr(std::is_same_v<DataType, int8_t>) { return 1e-1; }
    else if constexpr(std::is_same_v<DataType, ck::f8_t>)
    {
        return 16.1; // 240 and 224 are acceptable
    }
    else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
    {
        return 8192.1; // 57344 and 49152 are acceptable
    }
    else { return 1e-3; }
}

template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
    static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif

    using namespace ck::literals;

    auto M           = problem_size.M;
    auto N           = problem_size.N;
    auto K           = problem_size.K;
    auto StrideA     = problem_size.StrideA;
    auto StrideB     = problem_size.StrideB;
    auto StrideC     = problem_size.StrideC;
    auto Grid_size   = problem_size.Grid_size;
    auto Streamk_sel = problem_size.Streamk_sel;

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    auto f_get_default_stride =
        [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
            if(stride == -1)
            {
                // give a chance if stride is -1, return a default packed stride
                if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
                {
                    return static_cast<std::size_t>(col);
                }
                else
                {
                    return static_cast<std::size_t>(row);
                }
            }
            else
                return static_cast<std::size_t>(stride);
        };

    auto f_get_default_streamk_policy = [](ck::index_t streamk_sel) {
        if(streamk_sel == -1)
        {
            return static_cast<std::size_t>(4);
        }
        else
            return static_cast<std::size_t>(streamk_sel);
    };

    StrideA     = f_get_default_stride(M, K, StrideA, ALayout{});
    StrideB     = f_get_default_stride(K, N, StrideB, BLayout{});
    StrideC     = f_get_default_stride(M, N, StrideC, CLayout{});
    Streamk_sel = f_get_default_streamk_policy(Streamk_sel);

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));

    switch(config.init_method)
    {
    case 0:
        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
        break;
    case 1:
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
        break;
    case 2:
        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
        break;
    case 3:
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
        break;
    default:
        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
    }

    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
    std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;

#ifdef BUILD_INT4_EXAMPLE
    DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
                               c_m_n_device_result.mDesc.GetElementSpaceSize());

    const Tensor<KernelADataType> a_m_k_converted(a_m_k);
    const Tensor<KernelBDataType> b_k_n_converted(b_k_n);

    a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) *
                               c_m_n_device_result.mDesc.GetElementSpaceSize());

    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif

    DeviceMem workspace;

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CElementOp{};

    // do GEMM
    auto gemm      = DeviceGemmV2_Streamk_Instance{};
    auto invoker   = gemm.MakeInvoker();
    float ave_time = 0;

    auto argument = gemm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
        static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
        static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
        static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
        static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
        static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
        static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
        M,
        N,
        K,
        StrideA,
        StrideB,
        StrideC,
        Streamk_sel,
        Grid_size,
        a_element_op,
        b_element_op,
        c_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;

        return true;
    }

    bool pass = true;
    if(config.do_verification)
    {
        auto ref_gemm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
            a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});

        ref_invoker.Run(ref_argument);

        ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1});
#ifdef BUILD_INT4_EXAMPLE
        Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);

        c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());

        c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();

        return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
        c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());

        pass &= ck::utils::check_err(c_m_n_device_result,
                                     c_m_n_host_result,
                                     "Error: Incorrect results!",
                                     get_rtol<CDataType>(),
                                     get_atol<CDataType>());
#endif
    }

    if(config.time_kernel)
    {
        ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});

        std::size_t flop = 2_uz * M * N * K;
        std::size_t num_btype =
            sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

        float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

        float gb_per_sec = num_btype / 1.E6 / ave_time;

        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                  << " GB/s, " << gemm.GetTypeString() << std::endl;
    }

    return pass;
}

bool run_gemm_universal_streamk_example(int argc, char* argv[])
{
    ProblemSizeStreamK_universal problem_size;
    ExecutionConfig config;

    return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
}
example/CMakeLists.txt

@@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
     endforeach()
     #Do not build any WMMA examples if gfx11 targets are not on the list
     foreach(source IN LISTS FILE_NAME)
-        if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
+        if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
             message("removing wmma example ${source}")
             list(REMOVE_ITEM FILE_NAME "${source}")
         endif()
...
@@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
     endforeach()
     #Do not build any WMMA examples if gfx11 targets are not on the list
     foreach(source IN LISTS FILE_NAME)
-        if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
+        if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
             message("removing wmma example ${source}")
             list(REMOVE_ITEM FILE_NAME "${source}")
         endif()
...
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py

@@ -271,7 +271,9 @@ class FmhaBwdApiPool:
             per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
             if_i = 'if' if i == 0 else 'else if'
             per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
+        if not per_dtypes:
+            # empty string we add some ignore to suppress warning in api
+            per_dtypes += '    (void)t ; (void)s ; (void)a;'
         return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch=per_dtypes)

 # GEMM0: Q@K=S^T
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

@@ -278,6 +278,9 @@ class FmhaFwdApiPool:
             per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
             if_i = 'if' if i == 0 else 'else if'
             per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
+        if not per_dtypes:
+            # empty string we add some ignore to suppress warning in api
+            per_dtypes += '    (void)t ; (void)s ; (void)a;'
         return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes)

 @dataclass
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py

@@ -331,6 +331,9 @@ class FmhaFwdSplitKVApiPool:
             per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
             if_i = 'if' if i == 0 else 'else if'
             per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
+        if not per_dtypes:
+            # empty string we add some ignore to suppress warning in api
+            per_dtypes += '    (void)t ; (void)s ; (void)a;'
         return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch=per_dtypes)

 @dataclass
...
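The three codegen changes above all emit the same fallback into the generated C++ API when no dtype/hdim case is enabled: a set of void-casts so the otherwise-unused parameters do not trigger compiler warnings. A standalone illustration of that idiom (the function name and parameter types here are hypothetical; only the (void) casts mirror what the generator emits):

// When no kernel case is compiled in, the generated dispatch body would leave its
// parameters unused; casting each to void marks them as intentionally ignored.
float fmha_dispatch_stub(float t, int s, void* a)
{
    (void)t;
    (void)s;
    (void)a;
    return 0.f; // nothing to dispatch
}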
include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceGemm_Streamk_V2 : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                              const void* p_b,
                                                              void* p_c,
                                                              ck::index_t M,
                                                              ck::index_t N,
                                                              ck::index_t K,
                                                              ck::index_t StrideA,
                                                              ck::index_t StrideB,
                                                              ck::index_t StrideC,
                                                              ck::index_t Streamk_sel,
                                                              ck::index_t Grid_size,
                                                              AElementwiseOperation a_element_op,
                                                              BElementwiseOperation b_element_op,
                                                              CElementwiseOperation c_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
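A minimal sketch of driving a GEMM through this abstract interface (not part of the diff): it assumes `gemm` is a concrete instance such as the DeviceGemm_Xdl_CShuffle_Streamk_V3 specialization added next in this commit, that the element-wise operations are default-constructible (e.g. PassThrough), and that p_a/p_b/p_c point at valid device buffers.

// Launch one stream-k GEMM via the type-erased pointer API declared above.
template <typename DeviceOp>
float launch_streamk_gemm(DeviceOp& gemm,
                          const void* p_a, const void* p_b, void* p_c,
                          ck::index_t M, ck::index_t N, ck::index_t K,
                          ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC)
{
    // Streamk_sel = 1 selects 1-tile SK; Grid_size = -1 requests the max-occupancy default.
    auto argument = gemm.MakeArgumentPointer(p_a, p_b, p_c, M, N, K,
                                             StrideA, StrideB, StrideC,
                                             /*Streamk_sel=*/1, /*Grid_size=*/-1,
                                             {}, {}, {});
    auto invoker  = gemm.MakeInvokerPointer();

    if(!gemm.IsSupportedArgument(argument.get()))
        return -1.f; // this instance cannot handle the problem shape

    return invoker->Run(argument.get(), StreamConfig{nullptr, true}); // time the kernel
}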
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout, typename BLayout, typename CLayout,
          typename ADataType, typename BDataType, typename CDataType,
          typename GemmAccDataType, typename CShuffleDataType,
          typename AElementwiseOperation, typename BElementwiseOperation, typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
          index_t AK1, index_t BK1,
          index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
          typename ComputeTypeA                       = CDataType,
          typename ComputeTypeB                       = ComputeTypeA>
struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout,
                                                                         BLayout,
                                                                         CLayout,
                                                                         ADataType,
                                                                         BDataType,
                                                                         CDataType,
                                                                         AElementwiseOperation,
                                                                         BElementwiseOperation,
                                                                         CElementwiseOperation>
{
    // GridwiseGemm
    using GridwiseGemm = GridwiseGemm_xdl_cshuffle_streamk_v3<
        ALayout, BLayout, CLayout,
        ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType,
        AElementwiseOperation, BElementwiseOperation, CElementwiseOperation,
        GemmSpec,
        BlockSize, MPerBlock, NPerBlock, KPerBlock,
        AK1, BK1,
        MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        BlkGemmPipeSched,
        BlkGemmPipelineVer,
        ComputeTypeA,
        ComputeTypeB>;

    using Argument = typename GridwiseGemm::Argument;

    // Invoker
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                arg.Print();
            }

            if(!GridwiseGemm::CheckValidity(arg))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            float ave_time = 0;

            index_t k_grain = KPerBlock;
            index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;

            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);

            hipGetErrorString(hipMemsetAsync(
                arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));

            const auto Run = [&](const auto& kernel) {
                dim3 grid_dim;
                if(arg.Grid_size < 0)
                {
                    int occupancy, num_cu;
                    hipError_t rtn;
                    rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
                        &occupancy, kernel, BlockSize, 0);
                    hip_check_error(rtn);

                    hipDeviceProp_t dev_prop;
                    hipDevice_t dev;
                    rtn = hipGetDevice(&dev);
                    hip_check_error(rtn);
                    rtn = hipGetDeviceProperties(&dev_prop, dev);
                    hip_check_error(rtn);
                    num_cu = dev_prop.multiProcessorCount;

                    arg.Grid_size = num_cu * occupancy;
                    grid_dim      = arg.Grid_size;
                }
                else
                    grid_dim = arg.Grid_size;

                if(stream_config.flush_cache)
                {
                    Argument arg_ = arg;
                    ck::utility::RotatingMemWrapper<Argument> rotating_mem(
                        arg_,
                        stream_config.rotating_count,
                        arg_.M * arg_.K * sizeof(ADataType),
                        arg_.K * arg_.N * sizeof(BDataType));
                    rotating_mem.Print();

                    auto run_flush_cache = [&]() {
                        // flush icache
                        ck::utility::flush_icache();
                        // rotating mem
                        rotating_mem.Next();
                    };

                    ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
                        stream_config, run_flush_cache, kernel, grid_dim, dim3(BlockSize), 0, arg_);
                }
                else
                {
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
                }
            };

            constexpr index_t minimum_occupancy =
                BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;

            if(has_main_k_block_loop)
            {
                // Tail number always full
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
                             BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    {
                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true,
                            InMemoryDataOperationEnum::Set, minimum_occupancy>;
                        Run(kernel);
                    }
                }
                // Tail number could be One to Seven
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
                {
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true,
                                InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::One>;
                            Run(kernel);
                        }
                        else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true,
                                InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Full>;
                            Run(kernel);
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
                            {
                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true,
                                    InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Two>;
                                Run(kernel);
                            }
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
                            {
                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true,
                                    InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Three>;
                                Run(kernel);
                            }
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
                            {
                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true,
                                    InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Four>;
                                Run(kernel);
                            }
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
                            {
                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true,
                                    InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Five>;
                                Run(kernel);
                            }
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
                            {
                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true,
                                    InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Six>;
                                Run(kernel);
                            }
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
                            {
                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true,
                                    InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Seven>;
                                Run(kernel);
                            }
                        }
                    }
                }
                // Tail number could be Odd or Even
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
                {
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm, true,
                                InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Odd>;
                            Run(kernel);
                        }
                        else
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm, true,
                                InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Even>;
                            Run(kernel);
                        }
                    }
                }
                else
                {
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true,
                                InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Odd>;
                            Run(kernel);
                        }
                        else
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true,
                                InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Even>;
                            Run(kernel);
                        }
                    }
                }
            }
            else
            {
                // Tail number always 1
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    {
                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, false,
                            InMemoryDataOperationEnum::Set, minimum_occupancy>;
                        Run(kernel);
                    }
                }
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) &&
           !(GemmSpec == GemmSpecialization::MKPadding ||
             GemmSpec == GemmSpecialization::NKPadding ||
             GemmSpec == GemmSpecialization::MNKPadding ||
             GemmSpec == GemmSpecialization::KPadding))
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             index_t streamk_sel,
                             index_t Grid_size,
                             AElementwiseOperation,
                             BElementwiseOperation,
                             CElementwiseOperation)
    {
        return Argument{
            p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size}; // HS
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      index_t streamk_sel,
                                                      index_t Grid_size,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          streamk_sel,
                                          Grid_size);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};

        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
            {BlockGemmPipelineVersion::v1, "v1"},
            {BlockGemmPipelineVersion::v2, "v2"},
            {BlockGemmPipelineVersion::v3, "v3"},
            {BlockGemmPipelineVersion::v4, "v4"},
            {BlockGemmPipelineVersion::v5, "v5"}};

        // clang-format off
        str << "DeviceGemmXdlUniversal"
            << "<"
            << getGemmSpecializationString(GemmSpec) << ", "
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            << ">"
            << " BlkSize: "
            << BlockSize << ", "
            << "BlkTile: "
            << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
            << "WaveTile: "
            << MPerXDL << "x" << NPerXDL << ", "
            << "WaveMap: "
            << MXdlPerWave << "x" << NXdlPerWave << ", "
            << "VmemReadVec: "
            << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
            << "BlkGemmPipelineScheduler: "
            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: "
            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
            << "BlkGemmPipelinePrefetchStages: "
            << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

@@ -1404,4 +1404,326 @@ struct BlockToCTileMap_GemmStreamK

The hunk appends the following new struct between BlockToCTileMap_GemmStreamK and the closing `} // namespace ck`:

template <uint32_t MPerBlock_,
          uint32_t NPerBlock_,
          uint32_t KPerBlock_,
          StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
          uint32_t TileSwizzleSubM_                   = 8,
          index_t GroupNum                            = 8,
          index_t M01_                                = 4>
struct BlockToCTileMap_GemmStreamK_v2
{
    static constexpr uint32_t min_k_iters_per_sk_block = 2;
    static constexpr uint32_t MPerBlock                = MPerBlock_;
    static constexpr uint32_t NPerBlock                = NPerBlock_;
    static constexpr uint32_t KPerBlock                = KPerBlock_;
    static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
    static constexpr uint32_t tile_swizzle_sub_m                = TileSwizzleSubM_;

    //--------------------------------------
    // pass to device
    mutable uint32_t sk_num_blocks;
    uint32_t sk_num_big_blocks;
    uint32_t dp_start_block_idx;
    uint32_t reduction_start_block_idx;
    uint32_t k_iters_per_big_block;
    MDiv2 n_tiles;
    MDiv k_iters_per_tile;
    MDiv equiv_tiles_big;    // for reduction
    MDiv equiv_tiles_little; // for reduction

    // prefer construct on host
    __host__ __device__ BlockToCTileMap_GemmStreamK_v2(uint32_t m,
                                                       uint32_t n,
                                                       uint32_t k,
                                                       uint32_t grid_size   = 1,
                                                       uint32_t streamk_sel = 1)
    {
        // total output tiles
        uint32_t num_tiles =
            math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
        k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));

        uint32_t dp_tiles, dp_num_blocks, sk_total_iters;
        // default to regular DP GEMM if sk blocks == 0
        if(streamk_sel == 0)
        {
            sk_num_blocks         = 0;
            dp_tiles              = num_tiles;
            sk_num_big_blocks     = 0;
            k_iters_per_big_block = 0;

            dp_num_blocks      = num_tiles; // all tile to be dp block
            dp_start_block_idx = 0;
            sk_total_iters     = 0; // clear this tiles
        }
        // 2-tile sk + DP GEMM
        else
        {
            // check if there's enough work for DP+ stream-k
            bool bigEnough = num_tiles > grid_size;
            // select between stream-k strategies
            uint32_t sk_tiles = 0;
            if(streamk_sel == 1) // 1 tile stream-k
            {
                sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles;
            }
            else if(streamk_sel == 2) // 2-tile stream-k
            {
                sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles;
            }
            else if(streamk_sel == 3) // 3-tile stream-k
            {
                sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size)
                                                         : num_tiles;
            }
            else if(streamk_sel == 4) // 4-tile stream-k
            {
                sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size)
                                                         : num_tiles;
            }

            sk_num_blocks = sk_tiles;
            // remaining tiles are DP tiles
            dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;

            sk_total_iters = k_iters_per_tile.get() * sk_tiles;

            // k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
            // we need to decide how many iters for each sk block
            // let m = k_iters_per_sk_block
            // some of the sk block (little) will cover m iters, some (big) will cover m+1
            // we have
            // 1) l + b = sk_blocks
            // 2) l * m + b * (m + 1) = sk_total_iters
            //      => (l + b) * m + b = sk_total_iters
            //      => sk_blocks * m + b = sk_total_iters
            //      => b = sk_total_iters - m * sk_blocks
            //
            // NOTE: big could be zero
            uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
            sk_num_big_blocks     = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
            k_iters_per_big_block = k_iters_per_sk_block + 1;

            dp_num_blocks      = dp_tiles;
            dp_start_block_idx = sk_num_blocks;
        }

        n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
        // using multiple blocks for parallel reduction
        reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;

        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            uint32_t upper_big    = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
            uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
            equiv_tiles_big       = MDiv(upper_big / k_iters_per_tile.get());
            equiv_tiles_little    = MDiv(upper_little / k_iters_per_tile.get());
        }
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }

    __host__ __device__ uint32_t get_sk_total_iters() const
    {
        uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
                                  (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
        return sk_total_iters;
    }

    __host__ __device__ uint32_t get_sk_tiles() const
    {
        // tiles for sk
        uint32_t sk_total_iters = get_sk_total_iters();
        return k_iters_per_tile.div(sk_total_iters);
    }

    __host__ __device__ index_t get_grid_dims() const
    {
        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            // return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
            return reduction_start_block_idx + get_sk_tiles();
        }
        else
            return reduction_start_block_idx;
    }

    __device__ uint32_t get_block_idx() const
    {
        // TODO: swizzle block index for better locality
        return __builtin_amdgcn_readfirstlane(blockIdx.x);
    }

    __device__ void
    get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
    {
        if(block_idx < sk_num_big_blocks)
        {
            iter_start = block_idx * k_iters_per_big_block;
            iter_end   = iter_start + k_iters_per_big_block;
        }
        else if(block_idx < sk_num_blocks)
        {
            iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
                         (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
            iter_end = iter_start + (k_iters_per_big_block - 1);
        }
        else if(block_idx >= dp_start_block_idx)
        {
            uint32_t sk_total_iters     = get_sk_total_iters();
            uint32_t dp_iters_per_block = k_iters_per_tile.get();
            iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
            iter_end   = iter_start + dp_iters_per_block;
        }
    }

    __device__ uint32_t get_current_iter_length(uint32_t iter_start,
                                                uint32_t iter_end,
                                                uint32_t total_iter_length) const
    {
        uint32_t iter_length_mod, iter_length_quo /*unused*/;
        k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
        uint32_t current_iter_length = math::min(
            iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
        return current_iter_length;
    }

    __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }

    __device__ void
    get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
    {
        k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
    }

    __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
    {
        uint32_t m_tile_idx, n_tile_idx;
        uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
        n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);

        // // swizzle tile
        uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);

        uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;

        const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
                                     ? tile_swizzle_sub_m
                                     : tile_swizzle_sub_m_rem;

        uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
        m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
        m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;

        uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;

        uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;

        n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
        m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
        return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
                          n_tile_idx_with_adapt);
    }

    __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
    {
        static constexpr uint32_t alignment = 128;
        uint32_t acc_buffer_bytes =
            MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
        return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
    }

    __host__ __device__ uint32_t get_workspace_size_for_semaphore() const
    {
        return get_sk_tiles() * sizeof(uint32_t);
    }

    __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
    {
        return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
    }

    __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
                                                        const MDiv& equiv_tiles_) const
    {
        uint32_t tile_idx_        = tiles_ == 0 ? 0 : (tiles_ - 1);
        uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1;
        uint32_t quo_, rem_;
        equiv_tiles_.divmod(tile_idx_, quo_, rem_);
        return quo_ * max_equiv_tiles_ + rem_;
    }

    __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
                                                          uint32_t iters_per_sk_block_) const
    {
        return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ +
                                    k_iters_per_tile.get() - 1);
    }

    __host__ __device__ uint32_t get_total_acc_buffers() const
    {
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        uint32_t tiles_cover_little_blocks =
            get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);

        uint32_t total_intersec_big =
            get_tile_intersections(tiles_cover_big_blocks, equiv_tiles_big);
        uint32_t total_intersec_little =
            get_tile_intersections(tiles_cover_little_blocks, equiv_tiles_little);

        return sk_num_blocks + total_intersec_big + total_intersec_little;
    }

    __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
    {
        // TODO: from big to little
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        if(tile_idx_ < tiles_cover_big_blocks)
        {
            uint32_t touched_sk_blocks =
                (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
                k_iters_per_big_block;
            uint32_t current_intersec = get_tile_intersections(tile_idx_, equiv_tiles_big);
            return touched_sk_blocks + current_intersec;
        }
        else
        {
            uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
            uint32_t tile_idx_little_reverse   = get_sk_tiles() - tile_idx_;
            uint32_t touched_sk_blocks =
                (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
                iters_per_little_sk_block;
            uint32_t current_intersec =
                get_tile_intersections(tile_idx_little_reverse, equiv_tiles_little);
            return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
        }
    }

    __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
    {
        uint32_t iters_per_big_sk_block    = k_iters_per_big_block;
        uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
        if(block_idx_ < sk_num_big_blocks)
        {
            uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
                                                          k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_big);
            return block_idx_ + current_intersec;
        }
        else
        {
            uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
            uint32_t touched_tiles            = k_iters_per_tile.div(
                block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_little);
            return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
        }
    }
};

} // namespace ck
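The big/little split derived in the constructor comments above can be sanity-checked with a small worked example (illustration only; the sample numbers are made up, the formulas are the ones in the comment: m = sk_total_iters / sk_blocks and b = sk_total_iters − m·sk_blocks):

#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t sk_total_iters = 100, sk_num_blocks = 8;
    const uint32_t m = sk_total_iters / sk_num_blocks;     // iters per little block = 12
    const uint32_t b = sk_total_iters - m * sk_num_blocks; // number of big blocks   = 4
    const uint32_t l = sk_num_blocks - b;                  // number of little blocks = 4
    // 4 big blocks * 13 iters + 4 little blocks * 12 iters = 100 = sk_total_iters
    std::printf("big=%u x %u iters, little=%u x %u iters, total=%u\n",
                b, m + 1, l, m, b * (m + 1) + l * m);
    return 0;
}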
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
0 → 100644
View file @
15baccf2
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same
// kernel function Blockers:
// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on
// two lds chunks.
// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds
// buffer when we declare __shared__ inside blkgemmpipe
template
<
typename
GridwiseGemm
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
index_t
MinimumOccupancy
=
1
,
TailNumber
TailNum
=
TailNumber
::
Full
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
#endif
kernel_gemm_xdl_cshuffle_v3
(
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
p_shared
,
karg
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx9__))
}
template
<
typename
GridwiseGemm
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
index_t
MinimumOccupancy
=
1
,
TailNumber
TailNum
=
TailNumber
::
Full
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
#endif
kernel_gemm_xdl_cshuffle_v3_2lds
(
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
__shared__
char
p_shared_0
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared_1
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run_2Lds
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
p_shared_0
,
p_shared_1
,
karg
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx9__))
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1Value
,
index_t
BK1Value
,
index_t
MPerXdl
,
index_t
NPerXdl
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
BlockGemmPipelineScheduler
BlkGemmPipeSched
=
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
BlkGemmPipelineVer
=
BlockGemmPipelineVersion
::
v4
,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
GridwiseGemm_xdl_cshuffle_streamk_v3
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
// K1 should be Number<...>
static
constexpr
auto
AK0Number
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0Number
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1Number
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1Number
=
Number
<
BK1Value
>
{};
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1Number
,
BK1Number
),
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
    __host__ static auto CalculateMPadded(index_t M)
    {
        return math::integer_least_multiple(M, MPerBlock);
    }

    __host__ static auto CalculateNPadded(index_t N)
    {
        return math::integer_least_multiple(N, NPerBlock);
    }

    __host__ static auto CalculateKPadded(index_t K)
    {
        return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
    }

    __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
    }

    __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
    }

    __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;
    }

    __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
    {
        constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
        auto K_t                = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
    }

    __host__ static auto CalculateMBlock(index_t M)
    {
        return math::integer_divide_ceil(M, MPerBlock);
    }

    __host__ static auto CalculateNBlock(index_t N)
    {
        return math::integer_divide_ceil(N, NPerBlock);
    }
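
    // Worked example of the helpers above (illustrative, assumed tuning values MPerBlock = 256,
    // KPerBlock = 64, AK1Value = BK1Value = 8, K_Batch = 1):
    //   CalculateMPadded(1000)  -> 1024            (round up to a multiple of MPerBlock)
    //   CalculateKPadded(100)   -> 128             (round up to a multiple of KPerBlock)
    //   CalculateAK0Padded(100) -> 2 * (64/8) = 16 (number of K tiles times AK0 per tile)
    //   CalculateKRead(100)     -> 104             (round up only to the lcm of the K1 read vectors)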
    template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
    __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
    {
        constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
        constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});

        return transform_tensor_descriptor(
            TileDesc_K0_MN_K1{},
            make_tuple(
                make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
                make_unmerge_transform(
                    make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }
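
    // Reading the transform above: the (K0, MN, K1) tile is re-expressed as
    // (MNXdlPerWave, MNWaves, MNPerXdl, K) with K = K0 * K1 merged back together, which is the
    // shape the blockwise GEMM pipeline expects for its per-wave MFMA tiling; the index sequences
    // only re-route which input dimension feeds which output dimension.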
    __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
        index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
            }
        }();

        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and K
            const auto a_grid_desc_m_k = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_right_pad_transform(M, MPad - M),
                           make_right_pad_transform(K, KPad - K)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_pass_through_transform(MPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad M, but not K
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_right_pad_transform(M, MPad - M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad K, but not M
            const auto a_grid_desc_m_k = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // not pad M or K
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
    }
    __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
        index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
            }
        }();

        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            const auto b_grid_desc_n_k = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_right_pad_transform(N, NPad - N),
                           make_right_pad_transform(K, KPad - K)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_pass_through_transform(NPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N, but not K
            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_right_pad_transform(N, NPad - N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K, but not N
            const auto b_grid_desc_n_k = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // not pad N or K
            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
    }
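
    // Which padding branch the two builders above take is fixed at compile time by GemmSpec:
    // MNKPadding pads every dimension (so arbitrary M/N/K are accepted), the partial
    // specializations pad only the dimensions named in the enum value, and the default path keeps
    // the raw descriptors and relies on CheckValidity() further below to reject sizes that do not
    // divide the block tile.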
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
    MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
    {
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
    }
    __host__ __device__ static auto
    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
            }
        }();

        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M and N
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
                                               make_tuple(make_right_pad_transform(M, MPad - M),
                                                          make_right_pad_transform(N, NPad - N)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M, but not N
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N, but not M
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            // not pad M or N
            return c_grid_desc_mraw_nraw;
        }
    }
    struct Problem
    {
        __host__ Problem(index_t M_,
                         index_t N_,
                         index_t K_,
                         index_t StrideA_,
                         index_t StrideB_,
                         index_t StrideC_,
                         index_t Streamk_sel_,
                         index_t Grid_size_)
            : M{M_},
              N{N_},
              K{K_},
              StrideA{StrideA_},
              StrideB{StrideB_},
              StrideC{StrideC_},
              Streamk_sel{Streamk_sel_},
              Grid_size{Grid_size_},
              MPadded{CalculateMPadded(M_)},
              NPadded{CalculateNPadded(N_)},
              KRead{CalculateKRead(K_, 1)},
              KPadded{CalculateKPadded(K_, 1)},
              AK0{CalculateAK0Padded(K_, 1)},
              BK0{CalculateBK0Padded(K_, 1)},
              MBlock{CalculateMBlock(M_)},
              NBlock{CalculateNBlock(N_)}
        {
        }

        __host__ void Print() const
        {
            std::cout << "problem {"
                      << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
                      << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
                      << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
                      << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
                      << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
                      << "NBlock: " << NBlock << ", Stream-K Selection:" << Streamk_sel
                      << ", Grid size:" << Grid_size << "}" << std::endl;
        }

        index_t M;
        index_t N;
        index_t K;
        index_t StrideA;
        index_t StrideB;
        index_t StrideC;
        index_t Streamk_sel;
        mutable index_t Grid_size;
        index_t MPadded;
        index_t NPadded;
        index_t KRead;
        index_t KPadded;
        index_t AK0;
        index_t BK0;
        index_t MBlock;
        index_t NBlock;
    };
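
    // Illustrative sizing (assumed values): for M = 3840, N = 4096 with 256x256 tiles this gives
    // MBlock = 15 and NBlock = 16, i.e. 240 output tiles; Streamk_sel and Grid_size then steer
    // how Block2CTileMap_streamk (defined further below) splits the K iterations of those tiles
    // across the launched workgroups.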
    // Argument
    struct Argument : public tensor_operation::device::BaseArgument, public Problem
    {
        __host__ Argument(const ADataType* p_a_grid_,
                          const BDataType* p_b_grid_,
                          CDataType* p_c_grid_,
                          index_t M_,
                          index_t N_,
                          index_t K_,
                          index_t StrideA_,
                          index_t StrideB_,
                          index_t StrideC_,
                          index_t Streamk_sel_,
                          index_t Grid_size_)
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_},
              p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
              p_c_grid{p_c_grid_}
        {
        }

        const ADataType* p_a_grid;
        const BDataType* p_b_grid;
        CDataType* p_c_grid;
    };
    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(Problem& problem, unsigned int kbatch_id, unsigned int orig_K)
        {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                a_k_split_offset = kbatch_id * problem.KRead;
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                a_k_split_offset = kbatch_id * problem.KRead * problem.M;
            }

            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            {
                b_k_split_offset = kbatch_id * problem.KRead * problem.N;
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                b_k_split_offset = kbatch_id * problem.KRead;
            }

            if(kbatch_id < static_cast<uint32_t>(problem.KBatch - 1))
            {
                problem.K = problem.KRead;
            }
            else
            {
                problem.K = orig_K - problem.KRead * (problem.KBatch - 1);
            }
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
    };
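
    // Illustrative offset arithmetic (assumed values): with row-major A, KRead = 104 and
    // kbatch_id = 2, a_k_split_offset is 2 * 104 = 208 elements, i.e. each K-batch starts KRead
    // columns further along the K axis; with column-major A the K stride is M elements, hence the
    // extra problem.M factor. The last K-batch also shrinks problem.K so it does not read past
    // the original K extent.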
    __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
        if constexpr(ABlockLdsExtraM)
        {
            return make_naive_tensor_descriptor(
                make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
                make_tuple(AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1));
        }
        // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
        // in some cases.
        else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1
                                           ? 1
                                           : 32 * 4 / KPerBlock / sizeof(ADataType);
            constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
                make_tuple(AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
                make_tuple(AK1Number, Number<KPerBlock * MLdsLayer>{}, I1));

            constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
                a_lds_block_desc,
                make_tuple(make_xor_with_modulo_transform(make_tuple(
                               Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}));

            constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_permuted,
                make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
                           make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));

            constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_ak0_mldslayer_m_ak1,
                make_tuple(make_pass_through_transform(AK0Number),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return a_lds_block_desc_ak0_m_ak1;
        }
        else // ColumnMajor A
        {
            // kfold and mpair dimension is not always required.
            // more dimension in merge_transform increase the difficulty of generating immarg offset
            // for compiler.
            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
                                       ? 1
                                       : 128 / (AK1Number * M0 * sizeof(ADataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1<=mpair<=n0
            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
                                       ? 1
                                       : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
                                              ? M0
                                              : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));

            constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
                make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
                           Number<K0PerThreadWrite>{},
                           Number<KThreadReadPerm * M1>{},
                           Number<kfold * M0 / mpair>{},
                           Number<mpair>{},
                           AK1Number));

            constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
                a_lds_block_desc,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_xor_with_modulo_transform(
                        make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
                    make_pass_through_transform(Number<mpair>{}),
                    make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));

            constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
                a_lds_block_desc_permuted,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<M1>{})),
                    make_unmerge_transform(make_tuple(Number<kfold>{}, Number<M0 / mpair>{})),
                    make_pass_through_transform(Number<mpair>{}),
                    make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(Sequence<1>{}, Sequence<2>{}, Sequence<0, 3>{}, Sequence<4, 5>{}, Sequence<6>{}, Sequence<7>{}));

            constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_unmerged,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(Number<KThreadReadPerm>{},
                                          Number<KThreadWrite / kfold / KThreadReadPerm>{},
                                          Number<kfold>{},
                                          Number<K0PerThreadWrite>{})),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<M0 / mpair>{}, Number<mpair>{}, Number<M1>{})),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return a_lds_block_desc_ak0_m_ak1;
        }
    }
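
    // Illustrative note (assumed types): for fp16 A with KPerBlock = 64, MLdsLayer evaluates to
    // 32 * 4 / 64 / 2 = 1 and the row-major path above reduces to the plain (AK0, MPerBlock, AK1)
    // layout; for smaller KPerBlock the MLdsLayer factor together with the xor transform staggers
    // neighbouring M rows across LDS banks, which is the intent behind this descriptor.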
    __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
        if constexpr(BBlockLdsExtraN)
        {
            return make_naive_tensor_descriptor(
                make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
                make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
        }
        else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
        {
            // NLdsLayer * K0 as logical Bank
            constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1
                                           ? 1
                                           : 32 * 4 / KPerBlock / sizeof(BDataType);
            constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
                make_tuple(BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
                make_tuple(BK1Number, Number<KPerBlock * NLdsLayer>{}, I1));

            constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
                b_lds_block_desc,
                make_tuple(make_xor_with_modulo_transform(make_tuple(
                               Number<NPerBlock / NLdsLayer>{}, Number<BK0Number * NLdsLayer>{})),
                           make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}));

            constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_permuted,
                make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number<NLdsLayer>{})),
                           make_pass_through_transform(Number<NPerBlock / NLdsLayer>{}),
                           make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));

            constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_bk0_nldslayer_n_bk1,
                make_tuple(make_pass_through_transform(BK0Number),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<NPerBlock / NLdsLayer>{}, Number<NLdsLayer>{})),
                           make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return b_lds_block_desc_bk0_n_bk1;
        }
        else // RowMajor B
        {
            constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
            constexpr auto N1 = NPerBlock / N0;

            constexpr auto KThreadWrite     = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
            constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / NPerXdl;
            constexpr auto K0PerThreadRead  = BK0Number / KThreadRead;

            constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
                                       ? 1
                                       : 128 / (BK1Number * N0 * sizeof(BDataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1<=npair<=n0
            constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
                                       ? 1
                                       : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
                                              ? N0
                                              : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));

            constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
                make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
                           Number<K0PerThreadWrite>{},
                           Number<KThreadReadPerm * N1>{},
                           Number<kfold * N0 / npair>{},
                           Number<npair>{},
                           BK1Number));

            constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
                b_lds_block_desc,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_xor_with_modulo_transform(
                        make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
                    make_pass_through_transform(Number<npair>{}),
                    make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));

            constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
                b_lds_block_desc_permuted,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<N1>{})),
                    make_unmerge_transform(make_tuple(Number<kfold>{}, Number<N0 / npair>{})),
                    make_pass_through_transform(Number<npair>{}),
                    make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(Sequence<1>{}, Sequence<2>{}, Sequence<0, 3>{}, Sequence<4, 5>{}, Sequence<6>{}, Sequence<7>{}));

            constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_unmerged,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(Number<KThreadReadPerm>{},
                                          Number<KThreadWrite / kfold / KThreadReadPerm>{},
                                          Number<kfold>{},
                                          Number<K0PerThreadWrite>{})),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<N0 / npair>{}, Number<npair>{}, Number<N1>{})),
                           make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return b_lds_block_desc_bk0_n_bk1;
        }
    }
    __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
                           I1,
                           Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }
    using BlockwiseGemmPipe = remove_cvref_t<decltype(
        BlockGemmPipeline_Selector<BlkGemmPipelineVer,
                                   BlkGemmPipeSched,
                                   BlockSize,
                                   ADataType,
                                   BDataType,
                                   ComputeTypeA,
                                   AccDataType,
                                   decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
                                   decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
                                   decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
                                       GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())),
                                   decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
                                       GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())),
                                   ABlockTransferSrcScalarPerVector,
                                   BBlockTransferSrcScalarPerVector,
                                   MPerBlock,
                                   NPerBlock,
                                   KPerBlock,
                                   MPerXdl,
                                   NPerXdl,
                                   MXdlPerWave,
                                   NXdlPerWave,
                                   KPack>())>;
    __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        // LDS allocation for C shuffle in LDS
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max((a_block_space_size_aligned * sizeof(ADataType) +
                          b_block_space_size_aligned * sizeof(BDataType)),
                         c_block_size * sizeof(CShuffleDataType));
    }
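
    // Rough LDS budget example (all values assumed): for a 256x256 tile with KPerBlock = 64 and
    // fp16 A/B, each of the A and B slabs is about 64 * 256 * 2 bytes = 32 KiB, while the
    // C-shuffle staging area only covers CShuffle{M,N}XdlPerWavePerShuffle worth of the tile, so
    // the max() above is normally dominated by the A + B term.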
    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    __host__ static constexpr bool CheckValidity(const Argument& karg)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
        {
            if(!(karg.M % MPerBlock == 0))
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                              << std::endl;
                }
                return false;
            }
        }

        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
        {
            if(!(karg.N % NPerBlock == 0))
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                              << std::endl;
                }
                return false;
            }
        }

        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
        {
            auto K_t = KPerBlock;
            if(!(karg.K % K_t == 0))
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                              << karg.K << " " << __FILE__ << ":" << __LINE__
                              << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(karg.K <= 0)
            {
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg M (" << karg.M
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
        {
            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg N (" << karg.N
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
        {
            if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg N (" << karg.N
                              << ") value is not a multiple of "
                                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                              << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                              << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg M (" << karg.M
                              << ") value is not a multiple of "
                                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                              << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                              << std::endl;
                }
                return false;
            }
        }

        if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
        {
            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
            {
                std::cout << " Grid size: " << karg.Grid_size << " > 1 is not supported yet "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          << std::endl;
            }
        }

        // check gridwise gemm pipeline
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);

        if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
        {
            if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
            {
                return false;
            }
        }

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }
    __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;
        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
    }

    __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
    {
        const index_t num_loop = K / KPerBlock;
        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
    }
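
    // Both helpers above only derive the number of K iterations a block will run
    // (num_loop = K / KPerBlock) and defer to the selected pipeline: BlockHasHotloop decides
    // whether a steady-state main loop is emitted at all, and BlockLoopTailNum picks which tail
    // specialization (e.g. TailNumber::Odd) the caller should instantiate Run with.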
    template <typename CGridDesc>
    __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }
    using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2<MPerBlock,
                                                                  NPerBlock,
                                                                  KPerBlock,
                                                                  StreamKReductionStrategy::Atomic,
                                                                  8,
                                                                  4>;
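
    // How this map is consumed below: blocks whose index is below sk_num_blocks are stream-k
    // blocks that own only part of a tile's K range and therefore commit their partial results
    // with AtomicAdd, while blocks in [dp_start_block_idx, reduction_start_block_idx) are
    // data-parallel blocks that own whole tiles and write C with a plain Set.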
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Odd>
    __device__ static void Run(const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               CDataType* p_c_grid,
                               void* p_shared,
                               Problem& problem)
    {
        const AElementwiseOperation a_element_op{};
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};

        Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
                                                         problem.N,
                                                         AK0Number * problem.KPadded,
                                                         problem.Grid_size,
                                                         problem.Streamk_sel);
        uint32_t iter_start, iter_end;
        bool is_sk_block, is_dp_block;
        index_t num_k_block_main_loop;

        for(auto block_idx = get_block_1d_id();
            block_idx < block_2_ctile_map_streamk.get_grid_dims();
            block_idx += gridDim.x)
        {
            is_sk_block =
                static_cast<uint32_t>(block_idx) < block_2_ctile_map_streamk.sk_num_blocks;
            is_dp_block =
                static_cast<uint32_t>(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx &&
                static_cast<uint32_t>(block_idx) <
                    block_2_ctile_map_streamk.reduction_start_block_idx;

            block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
            num_k_block_main_loop = iter_end - iter_start;

            while(true)
            {
                uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
                    block_2_ctile_map_streamk.get_current_iter_length(
                        iter_start, iter_end, num_k_block_main_loop));
                uint32_t tile_idx, iter_offset;
                block_2_ctile_map_streamk.get_tile_idx_with_offset(
                    iter_end - 1, tile_idx, iter_offset);
                iter_offset =
                    __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);

                const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
                    problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
                const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
                    problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
                const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
                    problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);

                const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
                    MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        c_grid_desc_m_n, problem.MBlock, problem.NBlock);

                auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
                const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
                const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

                auto block_work_idx =
                    block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);

                const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
                const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

                // HACK: this force m/n_block_data_idx_on_grid into SGPR
                const index_t m_block_data_idx_on_grid =
                    __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
                const index_t n_block_data_idx_on_grid =
                    __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
                const index_t k0_block_data_idx_on_grid =
                    __builtin_amdgcn_readfirstlane(iter_offset * AK0Number);

                // lds max alignment
                constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

                // A matrix in LDS memory, dst of blockwise copy
                constexpr auto a_block_desc_ak0_m_ak1 =
                    GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

                // B matrix in LDS memory, dst of blockwise copy
                constexpr auto b_block_desc_bk0_n_bk1 =
                    GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

                // A matrix blockwise copy
                auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
                    ThisThreadBlock,
                    AElementwiseOperation,
                    ck::tensor_operation::element_wise::PassThrough,
                    InMemoryDataOperationEnum::Set,
                    Sequence<AK0Number, MPerBlock, AK1Number>,
                    ABlockTransferThreadClusterLengths_AK0_M_AK1,
                    ABlockTransferThreadClusterArrangeOrder,
                    ADataType,
                    ADataType,
                    decltype(a_grid_desc_ak0_m_ak1),
                    decltype(a_block_desc_ak0_m_ak1),
                    ABlockTransferSrcAccessOrder,
                    Sequence<0, 1, 2>,
                    ABlockTransferSrcVectorDim,
                    2,
                    ABlockTransferSrcScalarPerVector,
                    ABlockTransferDstScalarPerVector_AK1,
                    1,
                    1,
                    AThreadTransferSrcResetCoordinateAfterRun,
                    true,
                    BlockwiseGemmPipe::GlobalBufferNum>(
                    a_grid_desc_ak0_m_ak1,
                    make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
                    a_element_op,
                    a_block_desc_ak0_m_ak1,
                    make_multi_index(0, 0, 0),
                    ck::tensor_operation::element_wise::PassThrough{});

                // B matrix blockwise copy
                auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
                    ThisThreadBlock,
                    BElementwiseOperation,
                    ck::tensor_operation::element_wise::PassThrough,
                    InMemoryDataOperationEnum::Set,
                    Sequence<BK0Number, NPerBlock, BK1Number>,
                    BBlockTransferThreadClusterLengths_BK0_N_BK1,
                    BBlockTransferThreadClusterArrangeOrder,
                    BDataType,
                    BDataType,
                    decltype(b_grid_desc_bk0_n_bk1),
                    decltype(b_block_desc_bk0_n_bk1),
                    BBlockTransferSrcAccessOrder,
                    Sequence<0, 1, 2>,
                    BBlockTransferSrcVectorDim,
                    2,
                    BBlockTransferSrcScalarPerVector,
                    BBlockTransferDstScalarPerVector_BK1,
                    1,
                    1,
                    BThreadTransferSrcResetCoordinateAfterRun,
                    true,
                    BlockwiseGemmPipe::GlobalBufferNum>(
                    b_grid_desc_bk0_n_bk1,
                    make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
                    b_element_op,
                    b_block_desc_bk0_n_bk1,
                    make_multi_index(0, 0, 0),
                    ck::tensor_operation::element_wise::PassThrough{});

                // LDS allocation for A and B: be careful of alignment
                constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
                    a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

                // Cast after lds
                auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                    static_cast<ADataType*>(p_shared),
                    a_block_desc_ak0_m_ak1.GetElementSpaceSize());
                auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                    static_cast<BDataType*>(p_shared) +
                        a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
                    b_block_desc_bk0_n_bk1.GetElementSpaceSize());

                constexpr auto a_block_slice_copy_step =
                    make_multi_index(KPerBlock / AK1Number, 0, 0);
                constexpr auto b_block_slice_copy_step =
                    make_multi_index(KPerBlock / BK1Number, 0, 0);

                // Blockwise GEMM pipeline
                static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
                auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
                auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

                num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
                    (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
                    KPerBlock);

                blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                    a_grid_desc_ak0_m_ak1,
                    a_block_desc_ak0_m_ak1,
                    a_blockwise_copy,
                    a_grid_buf,
                    a_block_buf,
                    a_block_slice_copy_step,
                    b_grid_desc_bk0_n_bk1,
                    b_block_desc_bk0_n_bk1,
                    b_blockwise_copy,
                    b_grid_buf,
                    b_block_buf,
                    b_block_slice_copy_step,
                    c_thread_buf,
                    num_k_block_main_loop);

                // shuffle C and write out
                {
                    static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                                      NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                                  "wrong!");

                    constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
                    constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

                    // TODO: hacky, fix it!
                    constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                        blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

                    // TODO: hacky, fix it!
                    // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
                    constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                        blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

                    constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
                    constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
                    constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
                    constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
                    constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
                    constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
                    constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
                    constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

                    constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                        GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

                    auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                        static_cast<CShuffleDataType*>(p_shared),
                        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
                            .GetElementSpaceSize());

                    constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                        transform_tensor_descriptor(
                            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                            make_tuple(
                                make_freeze_transform(I0),
                                make_unmerge_transform(make_tuple(
                                    Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per
                                                                             // shuffle
                                    M1,                                      // M1 = MWave
                                    M2, // M2 * M3 * M4 = MPerXdl
                                    M3,
                                    M4)),
                                make_freeze_transform(I0),
                                make_unmerge_transform(make_tuple(
                                    Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per
                                                                             // shuffle
                                    N1,                                      // N1 = NWave
                                    N2))),                                   // N2 = NPerXdl
                            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                            make_tuple(Sequence<>{},
                                       Sequence<0, 2, 4, 5, 6>{},
                                       Sequence<>{},
                                       Sequence<1, 3, 7>{}));

                    // calculate origin of thread output tensor on global memory
                    //     blockwise GEMM c matrix starting index
                    const auto c_thread_mtx_on_block =
                        blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

                    const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
                    const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

                    const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                        make_single_stage_tensor_adaptor(
                            make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                            make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                            make_tuple(Sequence<0>{}));

                    const auto m_thread_data_on_block_idx =
                        m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                            make_multi_index(m_thread_data_on_block));

                    const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                        make_single_stage_tensor_adaptor(
                            make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                            make_tuple(Sequence<0, 1, 2>{}),
                            make_tuple(Sequence<0>{}));

                    const auto n_thread_data_on_block_idx =
                        n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                            make_multi_index(n_thread_data_on_block));

                    // shuffle: threadwise copy C from VGPR to LDS
                    auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                        AccDataType,
                        CShuffleDataType,
                        decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                        decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                        ck::tensor_operation::element_wise::PassThrough,
                        Sequence<CShuffleMXdlPerWavePerShuffle,
                                 CShuffleNXdlPerWavePerShuffle,
                                 I1,
                                 I1,
                                 M2,
                                 I1,
                                 M4,
                                 I1>,
                        Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                        7,
                        1,
                        InMemoryDataOperationEnum::Set,
                        1,
                        true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                              make_multi_index(0,
                                               0,
                                               m_thread_data_on_block_idx[I1],
                                               n_thread_data_on_block_idx[I1],
                                               m_thread_data_on_block_idx[I2],
                                               m_thread_data_on_block_idx[I3],
                                               m_thread_data_on_block_idx[I4],
                                               n_thread_data_on_block_idx[I2]),
                              ck::tensor_operation::element_wise::PassThrough{}};

                    // shuffle: blockwise copy C from LDS to global
                    auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2<
                        ThisThreadBlock,       // ThreadGroup
                        CElementwiseOperation, // ElementwiseOperation,
                        // CGlobalMemoryDataOperation, // DstInMemOp,
                        Sequence<1,
                                 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                 1,
                                 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                        Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                        CShuffleDataType,     // typename SrcData,
                        CDataType,            // typename DstData,
                        decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                        decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                        Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
                        3,                                              // index_t VectorDim,
                        CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
                        false, // bool ThreadTransferSrcResetCoordinateAfterRun,
                        false> // bool ThreadTransferDstResetCoordinateAfterRun>
                        {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                         make_multi_index(0, 0, 0, 0),
                         c_grid_desc_mblock_mperblock_nblock_nperblock,
                         make_multi_index(block_m_id, 0, block_n_id, 0),
                         c_element_op};

                    // space filling curve for threadwise C in VGPR
                    constexpr auto sfc_c_vgpr =
                        SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                          Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                          Sequence<CShuffleMXdlPerWavePerShuffle,
                                                   CShuffleNXdlPerWavePerShuffle,
                                                   1,
                                                   1,
                                                   M2,
                                                   1,
                                                   M4,
                                                   1>>{};

                    // space filling curve for shuffled blockwise C in global mem
                    constexpr auto sfc_c_global =
                        SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                          Sequence<0, 2, 1, 3>,
                                          Sequence<1,
                                                   CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                                   1,
                                                   CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

                    constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

                    static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

                    static_for<0, num_access, 1>{}([&](auto access_id) {
                        // make sure it's safe to write to LDS
                        block_sync_lds();

                        // each thread write its data from VGPR to LDS
                        c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                                      sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                                      c_thread_buf,
                                                      c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                                      c_shuffle_block_buf);

                        // make sure it's safe to read from LDS
                        block_sync_lds();

                        c_shuffle_block_copy_lds_to_global.SetSrcSliceOrigin(
                            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                            make_tuple(0, 0, 0, 0));

                        if(is_dp_block)
                        {
                            // each block copy its data from LDS to global
                            c_shuffle_block_copy_lds_to_global
                                .template Run<decltype(c_shuffle_block_buf),
                                              decltype(c_grid_buf),
                                              InMemoryDataOperationEnum::Set>(
                                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                                    c_shuffle_block_buf,
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                                    c_grid_buf);
                        }
                        else if(is_sk_block)
                        {
                            // each block copy its data from LDS to global
                            c_shuffle_block_copy_lds_to_global
                                .template Run<decltype(c_shuffle_block_buf),
                                              decltype(c_grid_buf),
                                              InMemoryDataOperationEnum::AtomicAdd>(
                                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                                    c_shuffle_block_buf,
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                                    c_grid_buf);
                        }

                        if constexpr(access_id < num_access - 1)
                        {
                            constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                            // move on C
                            c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                                c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                        }
                    });
                }

                // exit condition
                iter_end -= current_iter_length;
                if(iter_end <= iter_start)
                    break;
                // make sure next loop LDS is ready for use
                block_sync_lds();
            }
        }
    }
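
    // Run_2Lds below follows the same stream-k flow as Run (tile iteration, blockwise A/B copies,
    // blockwise GEMM pipeline, C shuffle and write-out); the difference is that it carves ping and
    // pong A/B buffers out of p_shared_0 and p_shared_1 so the pipeline can overlap filling one
    // LDS buffer with computing from the other.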
template
<
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
TailNumber
TailNum
=
TailNumber
::
Odd
>
__device__
static
void
Run_2Lds
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
void
*
p_shared_0
,
void
*
p_shared_1
,
Problem
&
problem
)
{
const
AElementwiseOperation
a_element_op
{};
const
BElementwiseOperation
b_element_op
{};
const
CElementwiseOperation
c_element_op
{};
Block2CTileMap_streamk
block_2_ctile_map_streamk
(
problem
.
M
,
problem
.
N
,
AK0Number
*
problem
.
KPadded
,
problem
.
Grid_size
);
uint32_t
iter_start
,
iter_end
;
bool
is_sk_block
,
is_dp_block
;
//, is_padding_block; //, is_reduction_block;
index_t
num_k_block_main_loop
;
for
(
auto
block_idx
=
get_block_1d_id
();
block_idx
<
block_2_ctile_map_streamk
.
get_grid_dims
();
block_idx
+=
gridDim
.
x
)
{
is_sk_block
=
static_cast
<
uint32_t
>
(
block_idx
)
<
block_2_ctile_map_streamk
.
sk_num_blocks
;
is_dp_block
=
static_cast
<
uint32_t
>
(
block_idx
)
>=
block_2_ctile_map_streamk
.
dp_start_block_idx
&&
static_cast
<
uint32_t
>
(
block_idx
)
<
block_2_ctile_map_streamk
.
reduction_start_block_idx
;
block_2_ctile_map_streamk
.
get_block_itr
(
block_idx
,
iter_start
,
iter_end
);
num_k_block_main_loop
=
iter_end
-
iter_start
;
{
uint32_t
current_iter_length
=
__builtin_amdgcn_readfirstlane
(
block_2_ctile_map_streamk
.
get_current_iter_length
(
iter_start
,
iter_end
,
num_k_block_main_loop
));
uint32_t
tile_idx
,
iter_offset
;
block_2_ctile_map_streamk
.
get_tile_idx_with_offset
(
iter_end
-
1
,
tile_idx
,
iter_offset
);
iter_offset
=
__builtin_amdgcn_readfirstlane
(
iter_offset
-
current_iter_length
+
1
);
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
block_work_idx
=
block_2_ctile_map_streamk
.
tile_to_spatial
(
tile_idx
,
problem
.
M
,
problem
.
N
);
const
index_t
block_m_id
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]);
const
index_t
block_n_id
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]);
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_m_id
*
MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_n_id
*
NPerBlock
);
const
index_t
k0_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
iter_offset
*
AK0Number
);
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1Number
,
BK1Number
);
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// A matrix blockwise copy
auto
a_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
AK0Number
,
MPerBlock
,
AK1Number
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ADataType
,
ADataType
,
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
ABlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
>
,
ABlockTransferSrcVectorDim
,
2
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
,
BlockwiseGemmPipe
::
GlobalBufferNum
>
(
a_grid_desc_ak0_m_ak1
,
make_multi_index
(
k0_block_data_idx_on_grid
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
a_block_desc_ak0_m_ak1
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// B matrix blockwise copy
auto
b_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
BK0Number
,
NPerBlock
,
BK1Number
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BDataType
,
BDataType
,
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
BBlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
>
,
BBlockTransferSrcVectorDim
,
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
,
BlockwiseGemmPipe
::
GlobalBufferNum
>
(
b_grid_desc_bk0_n_bk1
,
make_multi_index
(
k0_block_data_idx_on_grid
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
b_block_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
auto
a_block_buf_ping
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ADataType
*>
(
p_shared_0
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf_ping
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
BDataType
*>
(
p_shared_0
)
+
a_block_space_size_aligned
*
sizeof
(
ADataType
)
/
sizeof
(
BDataType
),
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
a_block_buf_pong
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ADataType
*>
(
p_shared_1
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf_pong
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
BDataType
*>
(
p_shared_1
)
+
a_block_space_size_aligned
*
sizeof
(
ADataType
)
/
sizeof
(
BDataType
),
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
a_block_bufs
=
make_tuple
(
a_block_buf_ping
,
a_block_buf_pong
);
auto
b_block_bufs
=
make_tuple
(
b_block_buf_ping
,
b_block_buf_pong
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1Number
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
BK1Number
,
0
,
0
);
// Blockwise GEMM pipeline
static_assert
(
std
::
is_default_constructible_v
<
BlockwiseGemmPipe
>
);
auto
blockwise_gemm_pipeline
=
BlockwiseGemmPipe
{};
auto
c_thread_buf
=
blockwise_gemm_pipeline
.
GetCThreadBuffer
();
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
KPerBlock
);
blockwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
,
TailNum
>(
a_grid_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
a_blockwise_copy
,
a_grid_buf
,
a_block_bufs
,
a_block_slice_copy_step
,
b_grid_desc_bk0_n_bk1
,
b_block_desc_bk0_n_bk1
,
b_blockwise_copy
,
b_grid_buf
,
b_block_bufs
,
b_block_slice_copy_step
,
c_thread_buf
,
num_k_block_main_loop
);
        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            // TODO: hacky, fix it!
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // TODO: hacky, fix it!
            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared_0),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                        M1,                                      // M1 = MWave
                        M2,                                      // M2 * M3 * M4 = MPerXdl
                        M3,
                        M4)),
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                        N1,                                      // N1 = NWave
                        N2))),                                   // N2 = NPerXdl
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));
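The single-stage merge-transform adaptors above are just mixed-radix decompositions of the flat m (or n) index into its per-level coordinates. A plain C++ equivalent of the m-side decomposition, with assumed lengths:

// Sketch only: the M0..M4 lengths below are assumptions for illustration.
#include <array>
#include <cstdio>

int main()
{
    const std::array<int, 5> lengths = {2, 2, 4, 2, 4}; // M0..M4, assumed
    int m = 37;                                         // flat m index within the block

    std::array<int, 5> idx{};
    for(int d = 4; d >= 0; --d)
    {
        idx[d] = m % lengths[d];
        m /= lengths[d];
    }
    std::printf("m0..m4 = %d %d %d %d %d\n", idx[0], idx[1], idx[2], idx[3], idx[4]);
    return 0;
}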
            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   CShuffleDataType,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            I1,
                                                            I1,
                                                            M2,
                                                            I1,
                                                            M4,
                                                            I1>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3],
                                     m_thread_data_on_block_idx[I4],
                                     n_thread_data_on_block_idx[I2]),
                    ck::tensor_operation::element_wise::PassThrough{}};
            // shuffle: blockwise copy C from LDS to global
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2<
                ThisThreadBlock,       // ThreadGroup
                CElementwiseOperation, // ElementwiseOperation,
                // CGlobalMemoryDataOperation, // DstInMemOp,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                CShuffleDataType,     // typename SrcData,
                CDataType,            // typename DstData,
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
                3,                                              // index_t VectorDim,
                CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
                false, // bool ThreadTransferSrcResetCoordinateAfterRun,
                false> // bool ThreadTransferDstResetCoordinateAfterRun>
                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(0, 0, 0, 0),
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(block_m_id, 0, block_n_id, 0),
                 c_element_op};
            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           1,
                                           1,
                                           M2,
                                           1,
                                           M4,
                                           1>>{};

            // space filling curve for shuffled blockwise C in global mem
            constexpr auto sfc_c_global =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
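Both curves walk their tile in chunks whose shape corresponds to one shuffle step, which is why the static_assert below can demand equal access counts. A quick consistency check with assumed (but mutually consistent) tuning numbers:

// Illustration only: all block/wave/Xdl sizes below are assumptions.
#include <cassert>

int main()
{
    const int MXdlPerWave = 4, NXdlPerWave = 2;
    const int CShuffleM = 2, CShuffleN = 1; // per-shuffle Xdl counts, assumed

    // VGPR curve: whole accumulator tile divided by one shuffle chunk
    const int vgpr_accesses = (MXdlPerWave / CShuffleM) * (NXdlPerWave / CShuffleN);

    const int MWave = 2, NWave = 2, MPerXdl = 32, NPerXdl = 32;
    const int MPerBlock = MXdlPerWave * MWave * MPerXdl; // 256
    const int NPerBlock = NXdlPerWave * NWave * NPerXdl; // 128

    // Global curve: block tile divided by one shuffle output footprint
    const int global_accesses = (MPerBlock / (CShuffleM * MWave * MPerXdl)) *
                                (NPerBlock / (CShuffleN * NWave * NPerXdl));

    assert(vgpr_accesses == global_accesses);
    return 0;
}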
            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread write its data from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                c_shuffle_block_copy_lds_to_global.SetSrcSliceOrigin(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    make_tuple(0, 0, 0, 0));

                if(is_dp_block)
                {
                    // each block copy its data from LDS to global
                    c_shuffle_block_copy_lds_to_global
                        .template Run<decltype(c_shuffle_block_buf),
                                      decltype(c_grid_buf),
                                      InMemoryDataOperationEnum::Set>(
                            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                            c_shuffle_block_buf,
                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                            c_grid_buf);
                }
                else if(is_sk_block)
                {
                    // each block copy its data from LDS to global
                    c_shuffle_block_copy_lds_to_global
                        .template Run<decltype(c_shuffle_block_buf),
                                      decltype(c_grid_buf),
                                      InMemoryDataOperationEnum::AtomicAdd>(
                            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                            c_shuffle_block_buf,
                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                            c_grid_buf);
                }

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
        }
    }
    }
};

} // namespace ck
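The Set/AtomicAdd split above is the heart of the stream-K write-out: a data-parallel block owns the full K range of its C tile and can overwrite it, while stream-K blocks each hold a partial K sum and must accumulate into C. A scalar host-side sketch of that reduction idea (not the kernel's actual code path):

// Host-side sketch only: sizes and the two-block split are illustrative.
#include <cassert>
#include <vector>

int main()
{
    const int K = 8;
    std::vector<float> a(K, 1.0f), b(K, 2.0f);

    // "dp" block: full reduction, then a plain store (InMemoryDataOperationEnum::Set)
    float c_dp = 0.0f;
    for(int k = 0; k < K; ++k) c_dp += a[k] * b[k];

    // two "sk" blocks: each reduces half of K and adds into C (AtomicAdd on the GPU)
    float c_sk = 0.0f;
    for(int k = 0; k < K / 2; ++k) c_sk += a[k] * b[k]; // stream-k block 0
    for(int k = K / 2; k < K; ++k) c_sk += a[k] * b[k]; // stream-k block 1

    assert(c_dp == c_sk);
    return 0;
}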
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/amd_smfmac.hpp"

namespace ck {

enum struct SmfmacInstr
{
    smfmac_f32_16x16x32f16 = 0,
    smfmac_f32_32x32x16f16,
    smfmac_f32_16x16x32bf16,
    smfmac_f32_32x32x16bf16,
};
template <SmfmacInstr instr>
struct smfmac_type;

template <>
struct smfmac_type<SmfmacInstr::smfmac_f32_16x16x32f16>
{
    static constexpr index_t group_size          = 4;
    static constexpr index_t num_groups_per_blk  = 1;
    static constexpr index_t num_regs_per_blk    = 4;
    static constexpr index_t num_threads_per_blk = 16;
    static constexpr index_t wave_size           = 64;
    static constexpr index_t num_input_blks      = 4;
    static constexpr index_t num_output_blks     = 1;
    static constexpr index_t m_per_blk           = 16;
    static constexpr index_t n_per_blk           = 16;
    static constexpr index_t k_per_blk           = 8;
    static constexpr bool is_k_reduction         = true;

    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
    {
        intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
    }
};

template <>
struct smfmac_type<SmfmacInstr::smfmac_f32_32x32x16f16>
{
    static constexpr index_t group_size          = 4;
    static constexpr index_t num_groups_per_blk  = 4;
    static constexpr index_t num_regs_per_blk    = 16;
    static constexpr index_t num_threads_per_blk = 32;
    static constexpr index_t wave_size           = 64;
    static constexpr index_t num_input_blks      = 2;
    static constexpr index_t num_output_blks     = 1;
    static constexpr index_t m_per_blk           = 32;
    static constexpr index_t n_per_blk           = 32;
    static constexpr index_t k_per_blk           = 16;
    static constexpr bool is_k_reduction         = true;

    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
    {
        intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
    }
};

template <>
struct smfmac_type<SmfmacInstr::smfmac_f32_16x16x32bf16>
{
    static constexpr index_t group_size          = 4;
    static constexpr index_t num_groups_per_blk  = 1;
    static constexpr index_t num_regs_per_blk    = 4;
    static constexpr index_t num_threads_per_blk = 16;
    static constexpr index_t wave_size           = 64;
    static constexpr index_t num_input_blks      = 4;
    static constexpr index_t num_output_blks     = 1;
    static constexpr index_t m_per_blk           = 16;
    static constexpr index_t n_per_blk           = 16;
    static constexpr index_t k_per_blk           = 8;
    static constexpr bool is_k_reduction         = true;

    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
    {
        intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
    }
};

template <>
struct smfmac_type<SmfmacInstr::smfmac_f32_32x32x16bf16>
{
    static constexpr index_t group_size          = 4;
    static constexpr index_t num_groups_per_blk  = 4;
    static constexpr index_t num_regs_per_blk    = 16;
    static constexpr index_t num_threads_per_blk = 32;
    static constexpr index_t wave_size           = 64;
    static constexpr index_t num_input_blks      = 2;
    static constexpr index_t num_output_blks     = 1;
    static constexpr index_t m_per_blk           = 32;
    static constexpr index_t n_per_blk           = 32;
    static constexpr index_t k_per_blk           = 16;
    static constexpr bool is_k_reduction         = true;

    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
    {
        intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
    }
};
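The smfmac instructions wrapped above implement 2:4 structured-sparse multiply-accumulate: out of every four consecutive A values along K, only two are stored, and the extra idx operand records their positions. A dense-reference sketch of what that computes (plain C++, not the ISA, and the packing layout here is an assumption for illustration):

// Reference check only: real hardware packs idx bits per register, not per byte.
#include <cassert>
#include <cstdint>

int main()
{
    // one group of 4 dense A values with 2:4 sparsity (positions 1 and 3 kept)
    const float a_dense[4] = {0.0f, 3.0f, 0.0f, 5.0f};
    const float b[4]       = {1.0f, 2.0f, 3.0f, 4.0f};

    // compressed storage: the two kept values plus their 2-bit positions
    const float   a_comp[2] = {3.0f, 5.0f};
    const uint8_t idx[2]    = {1, 3};

    float ref = 0.0f, sparse = 0.0f;
    for(int k = 0; k < 4; ++k) ref += a_dense[k] * b[k];
    for(int i = 0; i < 2; ++i) sparse += a_comp[i] * b[idx[i]];

    assert(ref == sparse);
    return 0;
}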
template <typename base_type,
          index_t MPerXdlops,
          index_t NPerXdlops,
          typename additional_type = base_type>
struct SmfmacSelector
{
    template <typename base_type_,
              index_t MPerXdlops_,
              index_t NPerXdlops_,
              typename additional_type_ = base_type_>
    static constexpr auto GetSmfmac();

    template <>
    static constexpr auto GetSmfmac<half_t, 16, 16>()
    {
        return SmfmacInstr::smfmac_f32_16x16x32f16;
    }

    template <>
    static constexpr auto GetSmfmac<half_t, 32, 32>()
    {
        return SmfmacInstr::smfmac_f32_32x32x16f16;
    }

    template <>
    static constexpr auto GetSmfmac<bhalf_t, 16, 16>()
    {
        return SmfmacInstr::smfmac_f32_16x16x32bf16;
    }

    template <>
    static constexpr auto GetSmfmac<bhalf_t, 32, 32>()
    {
        return SmfmacInstr::smfmac_f32_32x32x16bf16;
    }

    static constexpr auto selected_smfmac =
        smfmac_type<GetSmfmac<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};

    __host__ __device__ constexpr SmfmacSelector()
    {
        static_assert(selected_smfmac.group_size * selected_smfmac.num_groups_per_blk ==
                          selected_smfmac.num_regs_per_blk,
                      "wrong! num_regs_per_blk");

        static_assert(selected_smfmac.num_threads_per_blk == selected_smfmac.n_per_blk,
                      "n_per_blk != num_threads_per_blk");

        static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.num_input_blks ==
                          selected_smfmac.m_per_blk,
                      "m_per_blk != num_input_blks * num_regs_per_blk");

        static_assert(selected_smfmac.num_output_blks == selected_smfmac.num_input_blks ||
                          selected_smfmac.num_output_blks == 1,
                      "incorrect num_output_blks");

        static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.wave_size ==
                          selected_smfmac.m_per_blk * selected_smfmac.n_per_blk,
                      "num_regs_per_blk incorrect");

        static_assert(selected_smfmac.is_k_reduction ||
                          (selected_smfmac.num_input_blks == selected_smfmac.num_output_blks),
                      "is_k_reduction wrong!");
    }

    static constexpr index_t GetKPerXdlops()
    {
        return (selected_smfmac.is_k_reduction ? selected_smfmac.num_input_blks : 1) *
               selected_smfmac.k_per_blk;
    }

    static constexpr index_t GetK1PerXdlops() { return selected_smfmac.k_per_blk; }
};
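GetKPerXdlops() above scales k_per_blk by num_input_blks whenever the instruction reduces across K. Evaluating the same formula for the two tile shapes in this header, using only the constants listed above:

// Mirrors the selector's formula with the constants from this header.
#include <cstdio>

int main()
{
    struct { const char* name; int num_input_blks; int k_per_blk; } instr[] = {
        {"smfmac_f32_16x16x32*", 4, 8},
        {"smfmac_f32_32x32x16*", 2, 16},
    };
    for(const auto& i : instr)
    {
        const bool is_k_reduction = true; // true for all four specializations above
        const int k_per_xdlops = (is_k_reduction ? i.num_input_blks : 1) * i.k_per_blk;
        std::printf("%s: KPerXdlops = %d, K1PerXdlops = %d\n", i.name, k_per_xdlops, i.k_per_blk);
    }
    return 0;
}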
template <typename base_type,
          index_t MPerXdlops,
          index_t NPerXdlops,
          index_t KPack,
          typename additional_type = base_type>
struct SparseXdlopsGemm
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    using CIndex   = MultiIndex<2>;
    using CIndex4D = MultiIndex<4>;

    __device__ static constexpr index_t GetNumBlks() { return smfmac_instr.num_output_blks; }

    __device__ static constexpr index_t GetNumXdlops()
    {
        return MPerXdlops * NPerXdlops /
               (smfmac_instr.m_per_blk * smfmac_instr.n_per_blk * smfmac_instr.num_output_blks);
    }

    __host__ __device__ constexpr SparseXdlopsGemm()
    {
        static_assert(NPerXdlops == 16 || NPerXdlops == 32,
                      "Only support GemmNPerXdlops == 16 or 32 for smfmac xdlops");

        static_assert(MPerXdlops == 16 || MPerXdlops == 32,
                      "Only support GemmMPerXdlops == 16 or 32 for smfmac xdlops");

        static_assert(KPack % smfmac_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
    }
    // XDL output supporting C = A * B
    // M2_N2 -> M2_M3_M4_N2
    template <typename CDesc_M0_N0_M1_N1_M2_N2>
    __host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
        const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
    {
        const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
        const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
        const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
        const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);

        return transform_tensor_descriptor(
            c_desc_m0_n0_m1_n1_m2_n2,
            make_tuple(make_pass_through_transform(M0),
                       make_pass_through_transform(N0),
                       make_pass_through_transform(M1),
                       make_pass_through_transform(N1),
                       make_unmerge_transform(make_tuple(Number<smfmac_instr.num_groups_per_blk>{},
                                                         Number<smfmac_instr.num_input_blks>{},
                                                         Number<smfmac_instr.group_size>{})),
                       make_pass_through_transform(Number<smfmac_instr.num_threads_per_blk>{})),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4, 5, 6>{},
                       Sequence<7>{}));
    }

    template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
    __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
        const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
    {
        const auto G  = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
        const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
        const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
        const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
        const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);

        return transform_tensor_descriptor(
            c_desc_g_m0_n0_m1_n1_m2_n2,
            make_tuple(make_pass_through_transform(G),
                       make_pass_through_transform(M0),
                       make_pass_through_transform(N0),
                       make_pass_through_transform(M1),
                       make_pass_through_transform(N1),
                       make_unmerge_transform(make_tuple(smfmac_instr.num_groups_per_blk,
                                                         smfmac_instr.num_input_blks,
                                                         smfmac_instr.group_size)),
                       make_pass_through_transform(smfmac_instr.num_threads_per_blk)),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5>{},
                       Sequence<6>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5, 6, 7>{},
                       Sequence<8>{}));
    }
    __device__ static constexpr index_t GetRegSizePerXdlops()
    {
        return MPerXdlops * NPerXdlops / smfmac_instr.wave_size;
    }

    __device__ static constexpr index_t GetWaveSize() { return smfmac_instr.wave_size; }

    template <class FloatA, class FloatB, class Idx, class FloatC>
    __device__ void
    Run(const FloatA& p_a_wave, const FloatB& p_b_wave, const Idx& idx, FloatC& p_c_thread) const
    {
        static_assert(is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value,
                      "base base_type must be half or bfloat16!");

        static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) {
            smfmac_instr.template run<MPerXdlops, NPerXdlops>(
                p_a_wave[k], p_b_wave[k], idx[k], p_c_thread);
        });
    }
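Run() above issues exactly KPack / k_per_blk smfmac instructions, feeding the k-th slice of the compressed A, B, and index registers each time. A trivial count check under an assumed KPack:

// Illustration only: KPack is an assumption; k_per_blk matches the 16x16 f16 case above.
#include <cassert>

int main()
{
    const int KPack = 32, k_per_blk = 8;
    int issued = 0;
    for(int k = 0; k < KPack / k_per_blk; ++k)
        ++issued; // one intrin_smfmac_* call per iteration
    assert(issued == 4);
    return 0;
}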
    __device__ static auto GetLaneId() { return get_thread_local_1d_id() % smfmac_instr.wave_size; }

    __device__ static auto GetBlkIdx()
    {
        const auto laneId = GetLaneId();

        constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(
                make_tuple(1, smfmac_instr.num_input_blks, smfmac_instr.num_threads_per_blk))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        const auto blk_idx =
            threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));

        const auto blk_id = blk_idx[I1];
        const auto blk_td = blk_idx[I2];

        return make_tuple(blk_id, blk_td);
    }

    __host__ __device__ static auto CalculateAThreadOriginDataIndex()
    {
        const auto laneId  = GetLaneId();
        const auto blk_idx = GetBlkIdx();
        const auto blk_id  = blk_idx[I0];
        const auto blk_td  = blk_idx[I1];

        if constexpr(smfmac_instr.is_k_reduction)
        {
            return make_tuple(blk_id, blk_td);
        }
        else
        {
            return make_tuple(0, laneId);
        }
    }

    __host__ __device__ static auto CalculateBThreadOriginDataIndex()
    {
        const auto laneId  = GetLaneId();
        const auto blk_idx = GetBlkIdx();
        const auto blk_id  = blk_idx[I0];
        const auto blk_td  = blk_idx[I1];

        if constexpr(smfmac_instr.is_k_reduction)
        {
            return make_tuple(blk_id, blk_td);
        }
        else
        {
            return make_tuple(0, laneId);
        }
    }

    __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
    {
        const auto blk_idx = GetBlkIdx();
        const auto blk_id  = blk_idx[I0];
        const auto blk_td  = blk_idx[I1];

        index_t n_offset = blk_i * smfmac_instr.n_per_blk + blk_td;
        index_t m_offset = xdlops_i * smfmac_instr.m_per_blk + blk_id * smfmac_instr.group_size;

        return CIndex{m_offset, n_offset};
    }

    __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
    {
        const auto blk_idx = GetBlkIdx();
        const auto blk_id  = blk_idx[I0];
        const auto blk_td  = blk_idx[I1];

        return CIndex4D{I0, blk_id, I0, blk_td};
    }
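GetBlkIdx()/GetBeginOfThreadBlk() above reduce to simple modulo arithmetic on the lane id. A worked example for the 16x16 shape (num_threads_per_blk = 16, num_input_blks = 4, group_size = 4, m/n_per_blk = 16), with an arbitrary lane:

// Worked example only; the lane id is arbitrary.
#include <cstdio>

int main()
{
    const int num_threads_per_blk = 16, num_input_blks = 4;
    const int group_size = 4, m_per_blk = 16, n_per_blk = 16;

    const int lane   = 37;                                             // lane id within the wave
    const int blk_td = lane % num_threads_per_blk;                     // 5
    const int blk_id = (lane / num_threads_per_blk) % num_input_blks;  // 2

    const int xdlops_i = 0, blk_i = 0;
    const int n_offset = blk_i * n_per_blk + blk_td;
    const int m_offset = xdlops_i * m_per_blk + blk_id * group_size;

    std::printf("lane %d -> C origin (m=%d, n=%d)\n", lane, m_offset, n_offset);
    return 0;
}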
    static constexpr auto smfmac =
        SmfmacSelector<base_type, MPerXdlops, NPerXdlops, additional_type>{};

    static constexpr auto smfmac_instr = smfmac.selected_smfmac;

    static constexpr auto KPerXdlops  = smfmac.GetKPerXdlops();
    static constexpr auto K1PerXdlops = smfmac.GetK1PerXdlops();
    static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;

    __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
    {
        return make_tuple(
            Number<smfmac_instr.num_groups_per_blk>{}, I1, Number<smfmac_instr.group_size>{}, I1);
    }
};

} // namespace ck
include/ck_tile/core/arch/amd_buffer_addressing.hpp
@@ -54,233 +54,318 @@ template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payloa
 } // namespace impl

 // TODO: glc/slc/...
-template <index_t bytes>
+template <index_t bytes, bool pre_nop = false>
 struct buffer_load;

 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
 // TODO: strict aliasing rule seems fail when reinterpret_cast between vector type
 // (exp_vector_type(xxx))
-template <>
-struct buffer_load<16>
+template <bool pre_nop>
+struct buffer_load<16, pre_nop>
 {
     template <typename T>
     CK_TILE_DEVICE void operator()(T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
-                                   index_t /*flag*/ = 0)
+                                   index_t /*flag*/ = 0,
+                                   bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 16);
         using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
-        asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
+        else
+            asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
     }
 };

-template <>
-struct buffer_load<8>
+template <bool pre_nop>
+struct buffer_load<8, pre_nop>
 {
     template <typename T>
     CK_TILE_DEVICE void operator()(T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
-                                   index_t /*flag*/ = 0)
+                                   index_t /*flag*/ = 0,
+                                   bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 8);
         using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
-        asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
+        else
+            asm volatile("buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
     }
 };

-template <>
-struct buffer_load<4>
+template <bool pre_nop>
+struct buffer_load<4, pre_nop>
 {
     template <typename T>
     CK_TILE_DEVICE void operator()(T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
-                                   index_t /*flag*/ = 0)
+                                   index_t /*flag*/ = 0,
+                                   bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 4);
         using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
-        asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "buffer_load_dword %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
+        else
+            asm volatile("buffer_load_dword %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
     }
 };

-template <>
-struct buffer_load<2>
+template <bool pre_nop>
+struct buffer_load<2, pre_nop>
 {
     template <typename T>
     CK_TILE_DEVICE void operator()(T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
-                                   index_t /*flag*/ = 0)
+                                   index_t /*flag*/ = 0,
+                                   bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
         using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
-        asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
+        else
+            asm volatile("buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
     }
 };

-template <>
-struct buffer_load<1>
+template <bool pre_nop>
+struct buffer_load<1, pre_nop>
 {
     template <typename T>
     CK_TILE_DEVICE void operator()(T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
-                                   index_t /*flag*/ = 0)
+                                   index_t /*flag*/ = 0,
+                                   bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 4);
         using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
-        asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
+        else
+            asm volatile("buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
     }
 };
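The buffer_load<bytes, pre_nop> specializations above simply pick the raw buffer-load mnemonic by transfer size. The same mapping as a plain reference lookup:

// Reference table of the size-to-mnemonic mapping used by the specializations above.
#include <cstdio>

static const char* buffer_load_mnemonic(int bytes)
{
    switch(bytes)
    {
    case 16: return "buffer_load_dwordx4";
    case 8: return "buffer_load_dwordx2";
    case 4: return "buffer_load_dword";
    case 2: return "buffer_load_ushort";
    case 1: return "buffer_load_ubyte";
    default: return "unsupported";
    }
}

int main()
{
    for(int bytes : {16, 8, 4, 2, 1})
        std::printf("%2d bytes -> %s\n", bytes, buffer_load_mnemonic(bytes));
    return 0;
}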
-template <index_t bytes>
+template <index_t bytes, bool pre_nop = false>
 struct buffer_load_if;

-template <>
-struct buffer_load_if<16>
+template <bool pre_nop>
+struct buffer_load_if<16, pre_nop>
 {
     template <typename T>
     CK_TILE_DEVICE void operator()(T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
-                                   index_t flag = 0)
+                                   index_t flag = 0,
+                                   bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 16);
         auto saved_exec = __builtin_amdgcn_read_exec();
         using mbuf_t    = typename impl::buffer_load_trait<16, T>::payload_t;
         static_assert(sizeof(mbuf_t) == sizeof(T));
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
+        else
+            asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
     }
 };

-template <>
-struct buffer_load_if<8>
+template <bool pre_nop>
+struct buffer_load_if<8, pre_nop>
 {
     template <typename T>
     CK_TILE_DEVICE void operator()(T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
-                                   index_t flag = 0)
+                                   index_t flag = 0,
+                                   bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 8);
         auto saved_exec = __builtin_amdgcn_read_exec();
         using mbuf_t    = typename impl::buffer_load_trait<8, T>::payload_t;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
+        else
+            asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
     }
 };

-template <>
-struct buffer_load_if<4>
+template <bool pre_nop>
+struct buffer_load_if<4, pre_nop>
 {
     template <typename T>
     CK_TILE_DEVICE void operator()(T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
-                                   index_t flag = 0)
+                                   index_t flag = 0,
+                                   bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 4);
         auto saved_exec = __builtin_amdgcn_read_exec();
         using mbuf_t    = typename impl::buffer_load_trait<4, T>::payload_t;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_load_dword %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
+        else
+            asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
     }
 };

-template <>
-struct buffer_load_if<2>
+template <bool pre_nop>
+struct buffer_load_if<2, pre_nop>
 {
     template <typename T>
     CK_TILE_DEVICE void operator()(T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
-                                   index_t flag = 0)
+                                   index_t flag = 0,
+                                   bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 4);
         auto saved_exec = __builtin_amdgcn_read_exec();
         using mbuf_t    = typename impl::buffer_load_trait<2, T>::payload_t;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
+        else
+            asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
     }
 };

-template <>
-struct buffer_load_if<1>
+template <bool pre_nop>
+struct buffer_load_if<1, pre_nop>
 {
     template <typename T>
     CK_TILE_DEVICE void operator()(T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
-                                   index_t flag = 0)
+                                   index_t flag = 0,
+                                   bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 4);
         auto saved_exec = __builtin_amdgcn_read_exec();
         using mbuf_t    = typename impl::buffer_load_trait<1, T>::payload_t;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
+        else
+            asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
     }
 };
 #pragma clang diagnostic pop // "-Wundefined-reinterpret-cast"
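The exec-mask dance in buffer_load_if (v_cmpx_le_u32 followed by restoring exec) gives per-lane predication: a lane whose flag is 0 is masked off for the load and keeps whatever was in its destination register. A scalar model of those semantics, for reference only:

// Scalar model of the predicated-load semantics; not GPU code.
#include <cassert>

static float load_if(float old_value, float memory_value, int flag)
{
    return (flag >= 1) ? memory_value : old_value;
}

int main()
{
    assert(load_if(7.0f, 42.0f, 1) == 42.0f); // active lane loads from memory
    assert(load_if(7.0f, 42.0f, 0) == 7.0f);  // masked lane keeps its old data
    return 0;
}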
@@ -294,17 +379,16 @@ struct buffer_store<16>
     CK_TILE_DEVICE void operator()(const T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
                                    index_t /*flag*/ = 1)
     {
         static_assert(sizeof(T) == 16);
         using mbuf_t = fp32x4_t;
-        asm volatile("buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4"
-                     :
-                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3"
+                     :
+                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
+                     : "memory");
     }
 };
@@ -315,17 +399,16 @@ struct buffer_store<8>
     CK_TILE_DEVICE void operator()(const T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
                                    index_t /*flag*/ = 1)
     {
         static_assert(sizeof(T) == 8);
         using mbuf_t = fp32x2_t;
-        asm volatile("buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4"
-                     :
-                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3"
+                     :
+                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
+                     : "memory");
     }
 };
@@ -336,17 +419,16 @@ struct buffer_store<4>
     CK_TILE_DEVICE void operator()(const T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
                                    index_t /*flag*/ = 1)
     {
         static_assert(sizeof(T) == 4);
         using mbuf_t = float;
-        asm volatile("buffer_store_dword %0, %1, %2, %3 offen offset:%4"
-                     :
-                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3"
+                     :
+                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
+                     : "memory");
     }
 };
@@ -357,17 +439,16 @@ struct buffer_store<2>
     CK_TILE_DEVICE void operator()(const T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
                                    index_t /*flag*/ = 1)
     {
         static_assert(sizeof(T) == 2);
         using mbuf_t = short;
-        asm volatile("buffer_store_short %0, %1, %2, %3 offen offset:%4"
-                     :
-                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3"
+                     :
+                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
+                     : "memory");
     }
 };
@@ -378,17 +459,16 @@ struct buffer_store<1>
     CK_TILE_DEVICE void operator()(const T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
                                    index_t /*flag*/ = 1)
     {
         static_assert(sizeof(T) == 4);
         using mbuf_t = float;
-        asm volatile("buffer_store_byte %0, %1, %2, %3 offen offset:%4"
-                     :
-                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3"
+                     :
+                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
+                     : "memory");
     }
 };
@@ -402,21 +482,20 @@ struct buffer_store_if<16>
     CK_TILE_DEVICE void operator()(const T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
                                    index_t flag = 1)
     {
         static_assert(sizeof(T) == 16);
         auto save_exec = __builtin_amdgcn_read_exec();
         using mbuf_t   = fp32x4_t;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     :
-                     : "v"(bit_cast<mbuf_t>(value)),
-                       "v"(v_offset),
-                       "s"(res),
-                       "s"(s_offset),
-                       "n"(i_offset),
-                       "v"(flag),
-                       "s"(save_exec)
+        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                     "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
+                     "s_mov_b64 exec %5"
+                     :
+                     : "v"(bit_cast<mbuf_t>(value)),
+                       "v"(v_offset),
+                       "s"(res),
+                       "n"(i_offset),
+                       "v"(flag),
+                       "s"(save_exec)
@@ -431,7 +510,7 @@ struct buffer_store_if<8>
     CK_TILE_DEVICE void operator()(const T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
                                    index_t flag = 1)
     {
@@ -439,14 +518,13 @@ struct buffer_store_if<8>
         auto save_exec = __builtin_amdgcn_read_exec();
         // TODO: ugly. rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch
         using mbuf_t = ext_vector_t<typename T::value_type, T::size()>;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     :
-                     : "v"(bit_cast<mbuf_t>(value)),
-                       "v"(v_offset),
-                       "s"(res),
-                       "s"(s_offset),
-                       "n"(i_offset),
-                       "v"(flag),
-                       "s"(save_exec)
+        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                     "buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
+                     "s_mov_b64 exec %5"
+                     :
+                     : "v"(bit_cast<mbuf_t>(value)),
+                       "v"(v_offset),
+                       "s"(res),
+                       "n"(i_offset),
+                       "v"(flag),
+                       "s"(save_exec)
@@ -461,21 +539,20 @@ struct buffer_store_if<4>
     CK_TILE_DEVICE void operator()(const T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
                                    index_t flag = 1)
     {
         static_assert(sizeof(T) == 4);
         auto save_exec = __builtin_amdgcn_read_exec();
         using mbuf_t   = float;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_store_dword %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     :
-                     : "v"(bit_cast<mbuf_t>(value)),
-                       "v"(v_offset),
-                       "s"(res),
-                       "s"(s_offset),
-                       "n"(i_offset),
-                       "v"(flag),
-                       "s"(save_exec)
+        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                     "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n"
+                     "s_mov_b64 exec %5"
+                     :
+                     : "v"(bit_cast<mbuf_t>(value)),
+                       "v"(v_offset),
+                       "s"(res),
+                       "n"(i_offset),
+                       "v"(flag),
+                       "s"(save_exec)
@@ -490,21 +567,20 @@ struct buffer_store_if<2>
     CK_TILE_DEVICE void operator()(const T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
                                    index_t flag = 1)
     {
         static_assert(sizeof(T) == 2);
         auto save_exec = __builtin_amdgcn_read_exec();
         using mbuf_t   = short;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_store_short %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     :
-                     : "v"(bit_cast<mbuf_t>(value)),
-                       "v"(v_offset),
-                       "s"(res),
-                       "s"(s_offset),
-                       "n"(i_offset),
-                       "v"(flag),
-                       "s"(save_exec)
+        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                     "buffer_store_short %0, %1, %2, 0 offen offset:%3\n"
+                     "s_mov_b64 exec %5"
+                     :
+                     : "v"(bit_cast<mbuf_t>(value)),
+                       "v"(v_offset),
+                       "s"(res),
+                       "n"(i_offset),
+                       "v"(flag),
+                       "s"(save_exec)
@@ -519,21 +595,20 @@ struct buffer_store_if<1>
     CK_TILE_DEVICE void operator()(const T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
                                    index_t flag = 1)
     {
         static_assert(sizeof(T) == 4);
         auto save_exec = __builtin_amdgcn_read_exec();
         using mbuf_t   = float;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_store_byte %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     :
-                     : "v"(bit_cast<mbuf_t>(value)),
-                       "v"(v_offset),
-                       "s"(res),
-                       "s"(s_offset),
-                       "n"(i_offset),
-                       "v"(flag),
-                       "s"(save_exec)
+        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                     "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n"
+                     "s_mov_b64 exec %5"
+                     :
+                     : "v"(bit_cast<mbuf_t>(value)),
+                       "v"(v_offset),
+                       "s"(res),
+                       "n"(i_offset),
+                       "v"(flag),
+                       "s"(save_exec)
@@ -901,17 +976,26 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
                                        int soffset, // dst_wave_addr_offset
                                        int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");

-CK_TILE_DEVICE void async_buffer_load_dword(void* smem,
-                                            int32x4_t rsrc,
-                                            index_t voffset,
-                                            index_t soffset,
-                                            index_t ioffset /*max 0xFFF*/,
-                                            index_t /*flag*/ = 0)
+template <bool pre_nop = false>
+CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
+                                              int32x4_t rsrc,
+                                              index_t voffset,
+                                              index_t /*soffset*/,
+                                              index_t ioffset /*max 0xFFF*/,
+                                              index_t /*flag*/ = 0,
+                                              bool_constant<pre_nop> = {})
 {
-    asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds"
-                 : "=r"(smem) /*dummy dependency for smem*/
-                 : "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset)
-                 : "memory");
+    if constexpr(pre_nop)
+        asm volatile("s_nop 4\n"
+                     "buffer_load_dword %1, %2, 0 offen offset:%3 lds"
+                     : "=r"(smem) /*dummy dependency for smem*/
+                     : "v"(voffset), "s"(rsrc), "n"(ioffset)
+                     : "memory");
+    else
+        asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds"
+                     : "=r"(smem) /*dummy dependency for smem*/
+                     : "v"(voffset), "s"(rsrc), "n"(ioffset)
+                     : "memory");
 }

 CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
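The async path above loads straight into LDS (note the lds modifier), and async_buffer_load_fence is then used to wait until at most cnt such loads remain outstanding before the LDS data is consumed. A host-side model of that issue/fence pairing, semantics assumed from the usual waitcnt behavior:

// Host-side model only; the "oldest completes first" queue is an assumption.
#include <cassert>
#include <deque>

struct async_model
{
    std::deque<int> outstanding;
    void issue(int tag) { outstanding.push_back(tag); }
    void fence(std::size_t cnt = 0)
    {
        while(outstanding.size() > cnt)
            outstanding.pop_front(); // oldest load completes
    }
};

int main()
{
    async_model q;
    q.issue(0);
    q.issue(1);
    q.fence(1);                        // keep at most one load in flight
    assert(q.outstanding.size() == 1);
    q.fence();                         // drain completely before reading LDS
    assert(q.outstanding.empty());
    return 0;
}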
@@ -1223,12 +1307,14 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
 template <typename T,
           index_t N,
           amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
-          bool oob_conditional_check = true>
+          bool oob_conditional_check = true,
+          bool pre_nop               = false>
 CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
                                              int32x4_t src_wave_buffer_resource,
                                              index_t src_thread_addr_offset,
                                              index_t src_wave_addr_offset,
-                                             index_t flag = 0)
+                                             index_t flag = 0,
+                                             bool_constant<pre_nop> = {})
 {
     constexpr index_t bytes = sizeof(T) * N;
     static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
@@ -1237,32 +1323,46 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
     using type = thread_buffer<T, N>;
     if constexpr(oob_conditional_check)
     {
-        buffer_load_if<sizeof(type)>{}(dst,
-                                       src_wave_buffer_resource,
-                                       src_thread_addr_offset,
-                                       src_wave_addr_offset,
-                                       0,
-                                       flag);
+        buffer_load_if<sizeof(type), pre_nop>{}(dst,
+                                                src_wave_buffer_resource,
+                                                src_thread_addr_offset,
+                                                src_wave_addr_offset,
+                                                0,
+                                                flag,
+                                                bool_constant<pre_nop>{});
     }
     else
     {
-        buffer_load<sizeof(type)>{}(dst,
-                                    src_wave_buffer_resource,
-                                    src_thread_addr_offset,
-                                    src_wave_addr_offset,
-                                    0,
-                                    flag);
+        buffer_load<sizeof(type), pre_nop>{}(dst,
+                                             src_wave_buffer_resource,
+                                             src_thread_addr_offset,
+                                             src_wave_addr_offset,
+                                             0,
+                                             flag,
+                                             bool_constant<pre_nop>{});
     }
 }
 template <typename T,
           index_t N,
-          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool pre_nop                        = false>
 CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
                                                int32x4_t src_wave_buffer_resource,
                                                index_t src_thread_addr_offset,
                                                index_t src_wave_addr_offset,
-                                               index_t src_immediate_addr_offset = 0)
+                                               index_t src_immediate_addr_offset = 0,
+                                               bool_constant<pre_nop> = {})
 {
     static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");

-    async_buffer_load_dword(smem,
-                            src_wave_buffer_resource,
-                            src_thread_addr_offset,
-                            src_wave_addr_offset,
-                            src_immediate_addr_offset);
+    async_buffer_load_dword_v(smem,
+                              src_wave_buffer_resource,
+                              src_thread_addr_offset,
+                              src_wave_addr_offset,
+                              src_immediate_addr_offset,
+                              0,
+                              bool_constant<pre_nop>{});
 }

 template <index_t N,
@@ -1909,20 +2009,50 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
 template <typename T,
           index_t N,
           amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
-          bool oob_conditional_check = true>
+          bool oob_conditional_check = true,
+          bool pre_nop               = false>
 CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
                                         const T* p_src_wave,
                                         index_t src_thread_element_offset,
                                         index_t src_element_space_size,
-                                        index_t is_valid_element = 0)
+                                        index_t is_valid_element = 0,
+                                        bool_constant<pre_nop> = {})
 {
     const int32x4_t src_wave_buffer_resource =
         make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));

     index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

-    amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check>(
-        dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element);
+    amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
+        dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element,
+        bool_constant<pre_nop>{});
+}
+
+// This version support buffer resource as input arg
+template <typename T,
+          index_t N,
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool oob_conditional_check = true,
+          bool pre_nop               = false>
+CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
+                                        const int32x4_t src_wave_buffer_resource,
+                                        index_t src_thread_element_offset,
+                                        index_t is_valid_element = 0,
+                                        bool_constant<pre_nop> = {})
+{
+    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+
+    amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
+        dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element,
+        bool_constant<pre_nop>{});
 }
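Both amd_buffer_load_raw() overloads above reduce to byte arithmetic: the wave's buffer resource is sized in bytes, the per-thread offset is the element offset scaled by sizeof(T), and out-of-bounds handling comes from the validity flag. A small numeric sketch with assumed sizes:

// Example numbers only; T, the element space, and the offset are assumptions.
#include <cassert>
#include <cstdint>

int main()
{
    using T = std::uint16_t;                // e.g. an fp16-sized payload
    const std::size_t element_space = 4096; // elements covered by the buffer resource
    const std::size_t thread_element_offset = 130;

    const std::size_t resource_bytes     = element_space * sizeof(T);
    const std::size_t thread_byte_offset = thread_element_offset * sizeof(T);
    const int is_valid = thread_element_offset < element_space ? 1 : 0;

    assert(resource_bytes == 8192);
    assert(thread_byte_offset == 260);
    assert(is_valid == 1);
    return 0;
}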
// unfortunately async copy can not make sure invalid data is zero inside LDS
// unfortunately async copy can not make sure invalid data is zero inside LDS
...
@@ -1931,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
...
@@ -1931,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
// buffer_load OOB still working.
// buffer_load OOB still working.
template
<
typename
T
,
template
<
typename
T
,
index_t
N
,
index_t
N
,
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
>
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
,
CK_TILE_DEVICE
void
amd_async_buffer_load_with_oob
(
T
*
smem
,
bool
pre_nop
=
false
>
const
T
*
p_src_wave
,
CK_TILE_DEVICE
void
amd_async_buffer_load_with_oob_raw
(
T
*
smem
,
index_t
src_thread_element_offset
,
const
T
*
p_src_wave
,
index_t
src_element_space_size
)
index_t
src_thread_element_offset
,
index_t
src_element_space_size
,
bool_constant
<
pre_nop
>
=
{})
{
{
const
int32x4_t
src_wave_buffer_resource
=
const
int32x4_t
src_wave_buffer_resource
=
make_wave_buffer_resource
(
p_src_wave
,
src_element_space_size
*
sizeof
(
T
));
make_wave_buffer_resource
(
p_src_wave
,
src_element_space_size
*
sizeof
(
T
));
...
@@ -1943,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
...
@@ -1943,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
index_t
src_thread_addr_offset
=
src_thread_element_offset
*
sizeof
(
T
);
index_t
src_thread_addr_offset
=
src_thread_element_offset
*
sizeof
(
T
);
amd_async_buffer_load_impl
<
T
,
N
,
coherence
>
(
amd_async_buffer_load_impl
<
T
,
N
,
coherence
>
(
smem
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
,
0
);
smem
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
,
0
,
bool_constant
<
pre_nop
>
{});
}
// This version support buffer resource as input arg
template
<
typename
T
,
index_t
N
,
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
,
bool
pre_nop
=
false
>
CK_TILE_DEVICE
void
amd_async_buffer_load_with_oob_raw
(
T
*
smem
,
const
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_element_offset
,
bool_constant
<
pre_nop
>
=
{})
{
index_t
src_thread_addr_offset
=
src_thread_element_offset
*
sizeof
(
T
);
amd_async_buffer_load_impl
<
T
,
N
,
coherence
>
(
smem
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
,
0
,
bool_constant
<
pre_nop
>
{});
}
}
// buffer_store requires:
// buffer_store requires:
...
...
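The new overloads above take a prebuilt int32x4_t wave buffer resource plus a compile-time pre_nop flag, so a caller can build the buffer descriptor once and reuse it across raw loads. Below is a hedged call-pattern sketch, not part of the commit: it assumes the ck_tile umbrella header and namespace, uses only the signatures shown in this diff, and the names load_two_vectors / p_src / n are illustrative.

#include "ck_tile/core.hpp" // assumed umbrella header for the ck_tile core primitives

__device__ void load_two_vectors(const float* p_src, ck_tile::index_t n)
{
    using namespace ck_tile;

    // Build the 128-bit buffer resource (SRD) once; the second argument is a byte count.
    const int32x4_t res = make_wave_buffer_resource(p_src, n * sizeof(float));

    thread_buffer<float, 4> a, b;

    // Defaults: coherence_default, OOB check enabled, pre_nop off.
    amd_buffer_load_raw<float, 4>(a, res, /*src_thread_element_offset=*/0, /*is_valid_element=*/1);

    // Same cached resource for a second load; pre_nop requested at compile time.
    amd_buffer_load_raw<float, 4, amd_buffer_coherence_enum::coherence_default, true, true>(
        b, res, /*src_thread_element_offset=*/4, /*is_valid_element=*/1, bool_constant<true>{});
}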
include/ck_tile/core/arch/arch.hpp
View file @ 15baccf2
...
@@ -82,14 +82,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load()
     " ::);
 }

-CK_TILE_DEVICE void s_nop()
+CK_TILE_DEVICE void s_nop(index_t cnt = 0)
 {
 #if 1
-    asm volatile("\
-    s_nop 0 \n \
-    " ::);
+    asm volatile("s_nop %0" : : "n"(cnt) :);
 #else
-    __builtin_amdgcn_sched_barrier(0);
+    __builtin_amdgcn_sched_barrier(cnt);
 #endif
 }
...
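A minimal usage sketch (not from the commit) for the parameterized s_nop: the count is forwarded into the "s_nop %0" immediate through an "n" inline-asm constraint, so the argument has to fold to a compile-time constant; the wrapper function here is illustrative.

#include "ck_tile/core.hpp" // assumed umbrella header providing ck_tile::s_nop

__device__ void spacer()
{
    ck_tile::s_nop();  // emits "s_nop 0", same as the old zero-argument form
    ck_tile::s_nop(2); // emits "s_nop 2"
}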
include/ck_tile/core/config.hpp
View file @ 15baccf2
...
@@ -21,6 +21,7 @@
 #define __gfx12__
 #endif

+#include "hip/hip_version.h"
 #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
 #include "hip/hip_runtime.h"
 #include "hip/hip_fp16.h"
...
@@ -147,6 +148,14 @@
 #define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
 #endif

+#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091
+#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1
+#else
+#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
+#endif
+#endif
+
 #ifndef CK_TILE_DEBUG_LOG
 #define CK_TILE_DEBUG_LOG 0
 #endif
...
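The block above derives CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE from the HIP version macros pulled in by hip/hip_version.h (enabled for HIP 6.1 with patch >= 40091). A hedged sketch of how such a guard is typically consumed; the helper function below is illustrative, not from the commit:

#include "ck_tile/core/config.hpp"

// Hypothetical helper: report whether the ROCm 6.1 scratch-memory workaround is active,
// so callers can pick a conservative code path on affected HIP versions.
constexpr bool use_scratch_workaround()
{
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
    return true;
#else
    return false;
#endif
}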
include/ck_tile/core/tensor/buffer_view.hpp
View file @ 15baccf2
...
@@ -69,6 +69,8 @@ struct buffer_view<address_space_enum::generic,
     {
     }

+    CK_TILE_HOST_DEVICE void init_raw() {}
+
     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
     {
         return address_space_enum::generic;
...
@@ -224,25 +226,36 @@ struct buffer_view<address_space_enum::global,
     T* p_data_ = nullptr;
     BufferSizeType buffer_size_;
+    int32x4_t cached_buf_res_;
     remove_cvref_t<T> invalid_element_value_ = T{0};

     CK_TILE_HOST_DEVICE constexpr buffer_view()
-        : p_data_{}, buffer_size_{}, invalid_element_value_{}
+        : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
     {
     }

     CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
-        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
+        : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0}
     {
     }

     CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
                                               BufferSizeType buffer_size,
                                               T invalid_element_value)
-        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
+        : p_data_{p_data},
+          buffer_size_{buffer_size},
+          cached_buf_res_{0},
+          invalid_element_value_{invalid_element_value}
     {
     }

+    // this is non constexpr intentially (will call some intrinsic internally)
+    // Must call for buffers that need *_raw load/store
+    CK_TILE_HOST_DEVICE void init_raw()
+    {
+        cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
+    }
+
     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
     {
         return address_space_enum::global;
...
@@ -333,12 +346,15 @@ struct buffer_view<address_space_enum::global,
     // i is offset of T, not X. i should be aligned to X
     template <typename X,
               bool oob_conditional_check = true,
+              bool pre_nop = false,
               typename std::enable_if<
                   std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                   bool>::type = false>
     CK_TILE_DEVICE constexpr auto
-    get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const
+    get_raw(remove_cvref_t<X>& dst,
+            index_t i,
+            bool is_valid_element,
+            bool_constant<pre_nop> = {}) const
     {
         constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
...
@@ -349,18 +365,21 @@ struct buffer_view<address_space_enum::global,
         constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

-        amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
-            dst, p_data_, i, buffer_size_, is_valid_element);
+        amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
+            dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
     }

     // i is offset of T, not X. i should be aligned to X
     template <typename X,
+              bool pre_nop = false,
               typename std::enable_if<
                   std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                   bool>::type = false>
     CK_TILE_DEVICE constexpr auto
-    async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const
+    async_get_raw(remove_cvref_t<T>* smem,
+                  index_t i,
+                  bool /*is_valid_element*/,
+                  bool_constant<pre_nop> = {}) const
     {
         // X is vector of T
         constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
...
@@ -371,8 +390,8 @@ struct buffer_view<address_space_enum::global,
         constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

-        amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
-            smem, p_data_, i, buffer_size_);
+        amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
+            smem, cached_buf_res_, i, bool_constant<pre_nop>{});
     }

     // i is offset of T, not X. i should be aligned to X
...
@@ -627,6 +646,8 @@ struct buffer_view<address_space_enum::lds,
     {
     }

+    CK_TILE_HOST_DEVICE void init_raw() {}
+
     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
     {
         return address_space_enum::lds;
...
@@ -909,6 +930,8 @@ struct buffer_view<address_space_enum::vgpr,
     {
     }

+    CK_TILE_HOST_DEVICE void init_raw() {}
+
     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
     {
         return address_space_enum::vgpr;
...
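The comment in the diff ("Must call for buffers that need *_raw load/store") states the new contract: get_raw and async_get_raw now read cached_buf_res_ instead of rebuilding the resource from p_data_ and buffer_size_, so init_raw() has to run before the first raw access (it is an empty stub for the generic, lds, and vgpr views). A hedged, generic sketch of the calling order; BufView, Vec, and raw_load_after_init are placeholders, not from the commit:

#include "ck_tile/core.hpp" // assumed umbrella header

// Works for any buffer_view-like type that exposes init_raw() and get_raw().
template <typename BufView, typename Vec>
__device__ void raw_load_after_init(BufView& buf, Vec& dst, ck_tile::index_t i)
{
    buf.init_raw(); // for global views this fills cached_buf_res_ via make_wave_buffer_resource()
    buf.template get_raw<Vec>(dst, i, /*is_valid_element=*/true);
}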
include/ck_tile/core/tensor/load_tile.hpp
View file @ 15baccf2
...
@@ -36,30 +36,37 @@ template <typename T,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
-          bool oob_conditional_check = true>
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
 CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                   const tile_window_with_static_distribution<BottomTensorView_,
                                                                              WindowLengths_,
                                                                              TileDistribution_,
                                                                              NumCoord>& tile_window,
-                                  bool_constant<oob_conditional_check> = {})
+                                  bool_constant<oob_conditional_check> = {},
+                                  bool_constant<pre_nop> = {})
 {
-    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{});
+    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }

 template <typename LdsTileWindow_,
           typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
-          index_t NumCoord>
+          index_t NumCoord,
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
 CK_TILE_DEVICE auto
 async_load_tile_raw(LdsTileWindow_&& lds_tile,
                     const tile_window_with_static_distribution<BottomTensorView_,
                                                                WindowLengths_,
                                                                TileDistribution_,
-                                                               NumCoord>& tile_window)
+                                                               NumCoord>& tile_window,
+                    bool_constant<oob_conditional_check> = {},
+                    bool_constant<pre_nop> = {})
 {
-    return tile_window.async_load(lds_tile);
+    return tile_window.async_load_raw(
+        lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }

 CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
...
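A hedged sketch of how the two compile-time flags are forwarded through the free functions above. The surrounding kernel types (Tile, TileWindow, LdsWindow) and the wrapper function are placeholders; only the load_tile_raw, async_load_tile_raw, and async_load_fence calls reflect signatures shown in this commit:

#include "ck_tile/core.hpp" // assumed umbrella header

// Tile / TileWindow / LdsWindow stand in for the concrete distributed-tile types
// a real ck_tile kernel would instantiate here.
template <typename Tile, typename TileWindow, typename LdsWindow>
__device__ void load_with_flags(Tile& tile, LdsWindow& lds_tile, const TileWindow& window)
{
    using ck_tile::bool_constant;

    // Register-tile load: OOB check on, pre_nop off (the defaults, made explicit).
    ck_tile::load_tile_raw(tile, window, bool_constant<true>{}, bool_constant<false>{});

    // Async (direct-to-LDS) load with the same flags.
    ck_tile::async_load_tile_raw(lds_tile, window, bool_constant<true>{}, bool_constant<false>{});

    ck_tile::async_load_fence(); // wait for the outstanding async loads
}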