Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
522b7aee
Commit
522b7aee
authored
Jan 30, 2024
by
Adam Osewski
Browse files
Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2
parents
ff936fd6
84832fc4
Changes
130
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1066 additions
and
36 deletions
+1066
-36
example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp
...wd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp
+2
-3
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp
+2
-3
example/35_splitK_gemm/run_splitK_gemm_example.inc
example/35_splitK_gemm/run_splitK_gemm_example.inc
+1
-1
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
+1
-1
example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp
..._grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp
+2
-3
example/44_elementwise_permute/elementwise_permute.cpp
example/44_elementwise_permute/elementwise_permute.cpp
+3
-0
example/44_elementwise_permute/elementwise_permute_3d.cpp
example/44_elementwise_permute/elementwise_permute_3d.cpp
+9
-6
example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp
...le/44_elementwise_permute/elementwise_permute_4D_fp16.cpp
+3
-0
example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp
...44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp
+3
-0
example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp
...4_elementwise_permute/elementwise_permute_4D_fp16_col.cpp
+3
-0
example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp
...4_elementwise_permute/elementwise_permute_4D_fp16_row.cpp
+3
-0
example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp
...4_elementwise_permute/elementwise_permute_4D_fp32_col.cpp
+3
-0
example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp
...4_elementwise_permute/elementwise_permute_4D_fp32_row.cpp
+3
-0
example/48_pool3d_fwd/pool3d_fwd_common.hpp
example/48_pool3d_fwd/pool3d_fwd_common.hpp
+4
-0
example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp
example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp
+4
-0
include/ck/ck.hpp
include/ck/ck.hpp
+2
-2
include/ck/host_utility/hip_check_error.hpp
include/ck/host_utility/hip_check_error.hpp
+15
-13
include/ck/host_utility/kernel_launch.hpp
include/ck/host_utility/kernel_launch.hpp
+2
-2
include/ck/stream_config.hpp
include/ck/stream_config.hpp
+2
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
...or_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
+999
-0
No files found.
example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#include "common.hpp"
#include "common.hpp"
...
@@ -29,3 +27,4 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
...
@@ -29,3 +27,4 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
#include "run_grouped_conv_fwd_bias_relu_add_example.inc"
#include "run_grouped_conv_fwd_bias_relu_add_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_fwd_bias_relu_add_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_fwd_bias_relu_add_example
(
argc
,
argv
);
}
#endif
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp
View file @
522b7aee
...
@@ -9,9 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
...
@@ -9,9 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
Gemm1
Gemm1
*/
*/
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -144,3 +142,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
...
@@ -144,3 +142,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
#endif
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_batched_gemm_gemm_example
(
argc
,
argv
)
?
0
:
1
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_batched_gemm_gemm_example
(
argc
,
argv
)
?
0
:
1
;
}
#endif
example/35_splitK_gemm/run_splitK_gemm_example.inc
View file @
522b7aee
...
@@ -157,7 +157,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
...
@@ -157,7 +157,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
if
(
config
.
time_kernel
)
if
(
config
.
time_kernel
)
{
{
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
,
1
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
std
::
size_t
num_btype
=
...
...
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
View file @
522b7aee
...
@@ -42,7 +42,7 @@ using AElementOp = PassThrough;
...
@@ -42,7 +42,7 @@ using AElementOp = PassThrough;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
KPadding
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlSplitKCShuffle
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlSplitKCShuffle
// clang-format off
// clang-format off
...
...
example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#include <cstdlib>
#include <cstdlib>
#include <iostream>
#include <iostream>
...
@@ -120,3 +118,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
...
@@ -120,3 +118,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
#endif
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_grouped_conv_conv_fwd_example
(
argc
,
argv
)
?
0
:
1
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_grouped_conv_conv_fwd_example
(
argc
,
argv
)
?
0
:
1
;
}
#endif
example/44_elementwise_permute/elementwise_permute.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
...
...
example/44_elementwise_permute/elementwise_permute_3d.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
...
@@ -14,8 +17,8 @@
...
@@ -14,8 +17,8 @@
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
F32
=
float
;
using
ADataType
=
F
16
;
using
ADataType
=
F
32
;
using
BDataType
=
F
16
;
using
BDataType
=
F
32
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceElementwisePermuteInstance
=
using
DeviceElementwisePermuteInstance
=
...
@@ -25,10 +28,10 @@ using DeviceElementwisePermuteInstance =
...
@@ -25,10 +28,10 @@ using DeviceElementwisePermuteInstance =
2
,
// NumDim_m, {N, C}
2
,
// NumDim_m, {N, C}
2
,
// NumDim_n, {H, W}
2
,
// NumDim_n, {H, W}
1
,
// NumDim_k, {D}
1
,
// NumDim_k, {D}
8
,
// MPerThread
4
,
// MPerThread
8
,
// NPerThread
4
,
// NPerThread
8
,
// KPerThread
4
,
// KPerThread
ck
::
Sequence
<
8
>
,
// InScalarPerVectorSeq
ck
::
Sequence
<
4
>
,
// InScalarPerVectorSeq
ck
::
Sequence
<
4
>>
;
// OutScalarPerVectorSeq
ck
::
Sequence
<
4
>>
;
// OutScalarPerVectorSeq
template
<
typename
HostTensorA
,
typename
HostTensorB
,
typename
Functor
>
template
<
typename
HostTensorA
,
typename
HostTensorB
,
typename
Functor
>
...
...
example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
...
...
example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
...
...
example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
#include <random>
#include <random>
...
...
example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
...
...
example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
...
...
example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
...
...
example/48_pool3d_fwd/pool3d_fwd_common.hpp
View file @
522b7aee
...
@@ -32,6 +32,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
...
@@ -32,6 +32,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
return
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
};
return
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
};
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
return
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
};
return
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
};
throw
std
::
runtime_error
(
"Pool3d_fwd: problem with layout. "
);
return
{
0
,
0
,
0
,
0
,
0
};
};
};
template
<
typename
TensorLayout
>
template
<
typename
TensorLayout
>
...
@@ -53,6 +55,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
...
@@ -53,6 +55,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
});
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
});
}
}
throw
std
::
runtime_error
(
"Pool3d_fwd: problem with layout. "
);
return
HostTensorDescriptor
({
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
});
};
};
template
<
typename
DevicePoolFwdInstance
,
template
<
typename
DevicePoolFwdInstance
,
...
...
example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp
View file @
522b7aee
...
@@ -26,6 +26,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
...
@@ -26,6 +26,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
return
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
};
return
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
};
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
return
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
};
return
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
};
throw
std
::
runtime_error
(
"Avgpool3d_bwd: problem with layout. "
);
return
{
0
,
0
,
0
,
0
,
0
};
};
};
template
<
typename
TensorLayout
>
template
<
typename
TensorLayout
>
...
@@ -47,6 +49,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
...
@@ -47,6 +49,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
});
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
});
}
}
throw
std
::
runtime_error
(
"Avgpool3d_bwd: problem with layout. "
);
return
HostTensorDescriptor
({
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
});
};
};
template
<
typename
DevicePoolBwdInstance
,
template
<
typename
DevicePoolBwdInstance
,
...
...
include/ck/ck.hpp
View file @
522b7aee
...
@@ -213,12 +213,12 @@
...
@@ -213,12 +213,12 @@
#define CK_WORKAROUND_SWDEV_388832 1
#define CK_WORKAROUND_SWDEV_388832 1
// flag to enable (1) or disable (0) the debugging output in some kernels
// flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG
1
#define DEBUG_LOG
0
// denorm test fix, required to work around dissue
// denorm test fix, required to work around dissue
#ifndef CK_WORKAROUND_DENORM_FIX
#ifndef CK_WORKAROUND_DENORM_FIX
#define CK_WORKAROUND_DENORM_FIX 0
#define CK_WORKAROUND_DENORM_FIX 0
#el
if
#el
se
// enable only on MI200
// enable only on MI200
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#endif // CK_WORKAROUND_DENORM_FIX
#endif // CK_WORKAROUND_DENORM_FIX
...
...
include/ck/host_utility/hip_check_error.hpp
View file @
522b7aee
...
@@ -12,21 +12,23 @@ inline void hip_check_error(hipError_t x)
...
@@ -12,21 +12,23 @@ inline void hip_check_error(hipError_t x)
if
(
x
!=
hipSuccess
)
if
(
x
!=
hipSuccess
)
{
{
std
::
ostringstream
ss
;
std
::
ostringstream
ss
;
ss
<<
"HIP runtime error: "
<<
hipGetErrorString
(
x
)
<<
". "
<<
__FILE__
<<
": "
<<
__LINE__
ss
<<
"HIP runtime error: "
<<
hipGetErrorString
(
x
)
<<
". "
<<
"in function: "
<<
__func__
;
<<
"hip_check_error.hpp"
<<
": "
<<
__LINE__
<<
"in function: "
<<
__func__
;
throw
std
::
runtime_error
(
ss
.
str
());
throw
std
::
runtime_error
(
ss
.
str
());
}
}
}
}
#define HIP_CHECK_ERROR(retval_or_funcall) \
#define HIP_CHECK_ERROR(retval_or_funcall) \
do \
do \
{ \
{ \
hipError_t _tmpVal = retval_or_funcall; \
hipError_t _tmpVal = retval_or_funcall; \
if(_tmpVal != hipSuccess) \
if(_tmpVal != hipSuccess) \
{ \
{ \
std::ostringstream ostr; \
std::ostringstream ostr; \
ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \
ostr << "HIP Function Failed (" \
<< hipGetErrorString(_tmpVal); \
<< "hip_check_error.hpp" \
throw std::runtime_error(ostr.str()); \
<< "," << __LINE__ << ") " << hipGetErrorString(_tmpVal); \
} \
throw std::runtime_error(ostr.str()); \
} \
} while(0)
} while(0)
include/ck/host_utility/kernel_launch.hpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -30,7 +30,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
...
@@ -30,7 +30,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
block_dim
.
y
,
block_dim
.
y
,
block_dim
.
z
);
block_dim
.
z
);
printf
(
"Warm up
1
time
\n
"
);
printf
(
"Warm up
%d
time
s
\n
"
,
stream_config
.
cold_niters_
);
#endif
#endif
// warm up
// warm up
for
(
int
i
=
0
;
i
<
stream_config
.
cold_niters_
;
++
i
)
for
(
int
i
=
0
;
i
<
stream_config
.
cold_niters_
;
++
i
)
...
...
include/ck/stream_config.hpp
View file @
522b7aee
...
@@ -11,6 +11,6 @@ struct StreamConfig
...
@@ -11,6 +11,6 @@ struct StreamConfig
hipStream_t
stream_id_
=
nullptr
;
hipStream_t
stream_id_
=
nullptr
;
bool
time_kernel_
=
false
;
bool
time_kernel_
=
false
;
int
log_level_
=
0
;
int
log_level_
=
0
;
int
cold_niters_
=
1
;
int
cold_niters_
=
5
;
int
nrepeat_
=
1
0
;
int
nrepeat_
=
5
0
;
};
};
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
0 → 100644
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
// Double LDS buffer
// Prefetech 2 stage
// Local prefetch 1 stage
namespace
ck
{
template
<
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
ABufferLoadWidth
,
index_t
BBufferLoadWidth
,
index_t
ALDSWriteWidth
,
index_t
BLDSWriteWidth
,
index_t
ALDSReadWidth
,
index_t
BLDSReadWidth
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
KPerXDL
>
struct
BlockwiseGemmXdlops_pipeline_hotloop_inst
{
static
constexpr
index_t
WaveSize
=
64
;
static
constexpr
index_t
WaveNumM
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
static
constexpr
index_t
WaveNumN
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
static
constexpr
index_t
A_Buffer_Load_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
ABufferLoadWidth
);
static
constexpr
index_t
B_Buffer_Load_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
BBufferLoadWidth
);
static
constexpr
index_t
A_LDS_Write_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
ALDSWriteWidth
);
static
constexpr
index_t
B_LDS_Write_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
BLDSWriteWidth
);
static
constexpr
index_t
A_LDS_Read_Inst_Num
=
WaveNumN
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
ALDSReadWidth
);
static
constexpr
index_t
B_LDS_Read_Inst_Num
=
WaveNumM
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
BLDSReadWidth
);
static
constexpr
index_t
C_MFMA_Inst_Num
=
MPerBlock
*
NPerBlock
*
KPerBlock
/
(
BlockSize
/
WaveSize
)
/
(
MPerXDL
*
NPerXDL
*
KPerXDL
);
static
constexpr
auto
Print
()
{
printf
(
" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d
\n
"
,
BlockSize
,
WaveSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
KPerXDL
);
printf
(
" A/B buffer load inst: %d, %d
\n
A/B LDS write inst: %d, %d
\n
A/B LDS read inst: "
"%d, %d
\n
C MFMA inst: %d
\n
"
,
A_Buffer_Load_Inst_Num
,
B_Buffer_Load_Inst_Num
,
A_LDS_Write_Inst_Num
,
B_LDS_Write_Inst_Num
,
A_LDS_Read_Inst_Num
,
B_LDS_Read_Inst_Num
,
C_MFMA_Inst_Num
);
}
};
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
,
bool
TransposeC
=
false
,
index_t
AMmaKStride
=
KPack
*
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
KPack
,
FloatAB
,
TransposeC
>{}.
K0PerXdlops
,
index_t
BMmaKStride
=
KPack
*
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
KPack
,
FloatAB
,
TransposeC
>
{}.
K0PerXdlops
>
struct
BlockwiseGemmXdlops_pipeline_v4
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
static
constexpr
index_t
WaveSize
=
get_warp_size
();
static
constexpr
index_t
A_K0
=
ATileDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
B_K0
=
BTileDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
A_K1
=
ATileDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BTileDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
KPack
,
FloatAB
,
TransposeC
>
{};
static
constexpr
index_t
KPerThread
=
KPerBlock
/
xdlops_gemm
.
K0PerXdlops
;
static
constexpr
index_t
KRepeat
=
KPerThread
/
KPack
;
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
using
HotLoopInstList
=
BlockwiseGemmXdlops_pipeline_hotloop_inst
<
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
A_K1
,
B_K1
,
A_K1
,
B_K1
,
KPack
,
KPack
,
MRepeat
,
NRepeat
,
MPerXDL
,
NPerXDL
,
xdlops_gemm
.
KPerXdlops
>
;
static_assert
(
KPerThread
%
KPack
==
0
,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack"
);
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc
,
MRepeat
*
NRepeat
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
true
>
c_thread_buf_
;
__host__
__device__
constexpr
auto
&
GetCThreadBuffer
()
{
return
c_thread_buf_
;
}
__device__
static
auto
GetWaveIdx
()
{
const
index_t
thread_id
=
ThisThreadBlock
::
GetThreadId
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
MWaves
,
NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
xdlops_a_idx
=
xdlops_gemm
.
CalculateAThreadOriginDataIndex
();
return
make_tuple
(
0
,
waveId_m
,
xdlops_a_idx
[
I1
],
KPack
*
xdlops_a_idx
[
I0
]);
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
xdlops_b_idx
=
xdlops_gemm
.
CalculateBThreadOriginDataIndex
();
return
make_tuple
(
0
,
waveId_n
,
xdlops_b_idx
[
I1
],
KPack
*
xdlops_b_idx
[
I0
]);
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
auto
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
xdlops_gemm
.
GetBeginOfThreadBlk
(
xdlops_i
,
blk_i
);
constexpr
auto
mrepeat_mwave_mperxdl_to_m_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
MPerXDL
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
constexpr
auto
nrepeat_nwave_nperxdl_to_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
const
index_t
c_thread_m
=
mrepeat_mwave_mperxdl_to_m_adaptor
.
CalculateBottomIndex
(
make_tuple
(
m0
,
waveId_m
,
blk_idx
[
I0
]))[
I0
];
const
index_t
c_thread_n
=
nrepeat_nwave_nperxdl_to_n_adaptor
.
CalculateBottomIndex
(
make_tuple
(
n0
,
waveId_n
,
blk_idx
[
I1
]))[
I0
];
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
auto
CalculateCThreadOriginDataIndex8D
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
xdlops_gemm
.
GetBeginOfThreadBlk4D
(
xdlops_i
,
blk_i
);
return
make_tuple
(
m0
,
n0
,
waveId_m
,
waveId_n
,
blk_idx
[
I0
],
blk_idx
[
I1
],
blk_idx
[
I2
],
blk_idx
[
I3
]);
}
using
Tuple4
=
decltype
(
CalculateAThreadOriginDataIndex
());
__host__
__device__
BlockwiseGemmXdlops_pipeline_v4
(
Tuple4
a_origin
=
CalculateAThreadOriginDataIndex
(),
Tuple4
b_origin
=
CalculateBThreadOriginDataIndex
())
:
a_thread_copy_
(
a_origin
),
b_thread_copy_
(
b_origin
)
{
static_assert
(
AMmaTileDesc
::
IsKnownAtCompileTime
()
&&
BMmaTileDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
ThisThreadBlock
::
GetNumOfThread
()
==
MWaves
*
NWaves
*
WaveSize
,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerXDL
*
NRepeat
)
==
0
,
"wrong!"
);
// HotLoopInstList::Print();
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
N
,
M0
,
M1
,
M2
));
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
()
{
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
(
c_block_desc_m0_n0_m1_n1_m2_n2
);
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_block_desc_m0_n0_m1_n1_m2_n2
);
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_block_desc_g_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
c_block_desc_g_m0_n0_m1_n1_m2_n2
);
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
c_grid_desc_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerXDL
),
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerXDL
),
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m0_n0_m1_n1_m2_n2
);
}
template
<
typename
CGridDesc_G_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
)
{
const
auto
G
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
const
auto
M
=
c_grid_desc_g_m_n
.
GetLength
(
I1
);
const
auto
N
=
c_grid_desc_g_m_n
.
GetLength
(
I2
);
const
auto
c_grid_desc_g_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_g_m_n
,
make_tuple
(
make_pass_through_transform
(
G
),
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerXDL
),
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerXDL
),
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
2
,
4
,
6
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_g_m0_n0_m1_n1_m2_n2
);
}
__device__
static
constexpr
auto
HotLoopScheduler
()
{
// schedule
constexpr
auto
num_ds_read_inst
=
HotLoopInstList
::
A_LDS_Read_Inst_Num
+
HotLoopInstList
::
B_LDS_Read_Inst_Num
;
constexpr
auto
num_ds_write_inst
=
HotLoopInstList
::
A_LDS_Write_Inst_Num
+
HotLoopInstList
::
B_LDS_Write_Inst_Num
;
;
constexpr
auto
num_buffer_load_inst
=
HotLoopInstList
::
A_Buffer_Load_Inst_Num
+
HotLoopInstList
::
B_Buffer_Load_Inst_Num
;
;
constexpr
auto
num_mfma_inst
=
HotLoopInstList
::
C_MFMA_Inst_Num
;
constexpr
auto
num_issue
=
num_buffer_load_inst
;
static_for
<
0
,
num_issue
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_ds_read_inst
/
num_buffer_load_inst
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x200
,
num_ds_write_inst
/
num_buffer_load_inst
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_inst
/
num_buffer_load_inst
-
3
,
0
);
// MFMA
});
}
template
<
index_t
stage
>
__device__
static
constexpr
auto
TailScheduler
()
{
}
template
<
>
__device__
static
constexpr
auto
TailScheduler
<
1
>
()
{
// schedule
constexpr
auto
num_ds_read_inst
=
HotLoopInstList
::
A_LDS_Read_Inst_Num
+
HotLoopInstList
::
B_LDS_Read_Inst_Num
;
constexpr
auto
num_ds_write_inst
=
HotLoopInstList
::
A_LDS_Write_Inst_Num
+
HotLoopInstList
::
B_LDS_Write_Inst_Num
;
;
constexpr
auto
num_mfma_inst
=
HotLoopInstList
::
C_MFMA_Inst_Num
;
constexpr
auto
num_issue
=
num_ds_write_inst
;
static_for
<
0
,
num_issue
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x100
,
1
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_ds_read_inst
/
num_ds_write_inst
-
1
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_inst
/
num_ds_write_inst
-
3
,
0
);
// MFMA
});
}
template
<
>
__device__
static
constexpr
auto
TailScheduler
<
2
>
()
{
// schedule
constexpr
auto
num_ds_read_inst
=
HotLoopInstList
::
A_LDS_Read_Inst_Num
+
HotLoopInstList
::
B_LDS_Read_Inst_Num
;
constexpr
auto
num_mfma_inst
=
HotLoopInstList
::
C_MFMA_Inst_Num
;
constexpr
auto
num_issue
=
num_ds_read_inst
;
static_for
<
0
,
num_issue
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x100
,
1
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_inst
/
num_ds_read_inst
,
0
);
// MFMA
});
}
static
constexpr
AMmaTileDesc
a_block_desc_m0_m1_m2_k
;
static
constexpr
BMmaTileDesc
b_block_desc_n0_n1_n2_k
;
template
<
bool
HasMainLoop
,
index_t
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
const
{
__builtin_amdgcn_sched_barrier
(
0
);
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
StaticallyIndexedArray
<
decltype
(
a_thread_buf
),
Number
<
2
>
{}
>
a_thread_bufs
;
StaticallyIndexedArray
<
decltype
(
b_thread_buf
),
Number
<
2
>
{}
>
b_thread_bufs
;
// Inst List:
// ds_read_b128: 16
// ds_write_b128: 8
// buffer_load_dwordx4: 16
// v_mfma: 0
// -------------------------------------------------------------------------------------------
// Global prefetch 1th, Fill Ping LDS
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
I0
));
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
I0
));
// Local prefetch 1th, Fill Ping Reg
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
I0
),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
I0
));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
I0
),
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
I0
));
});
});
});
// Global prefetch 2th, Fill Pong LDS
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
I1
));
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
I1
));
// Global prefetch 3rd
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
// This hot loop has two legacy loopover, to implement the double local buffer strategy
do
{
// -------------------------------------------------------------------------------------------
using
PingP1
=
Number
<
0
>
;
using
PongP1
=
Number
<
1
>
;
// MFMA: Ping Reg
// DS_WRITE: To Ping LDS
// DS_READ: Pong LDS to Pong Reg
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
PongP1
{}),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
PongP1
{}));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
PongP1
{}),
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
PongP1
{}));
});
});
});
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
PingP1
{}));
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
PingP1
{}));
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
ik
)
=
a_thread_bufs
[
PingP1
{}][
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
ik
)
=
b_thread_bufs
[
PingP1
{}][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
// -------------------------------------------------------------------------------------------
using
PingP2
=
Number
<
1
>
;
using
PongP2
=
Number
<
0
>
;
// MFMA: Pong Reg
// DS_WRITE: To Pong LDS
// DS_READ: Ping LDS to Ping Reg
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
PongP2
{}),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
PongP2
{}));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
PongP2
{}),
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
PongP2
{}));
});
});
});
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
PingP2
{}));
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
PingP2
{}));
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
ik
)
=
a_thread_bufs
[
PingP2
{}][
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
ik
)
=
b_thread_bufs
[
PingP2
{}][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
i
+=
2
;
}
while
(
i
<
(
num_loop
-
3
));
}
// tail
if
constexpr
(
TailNum
==
3
)
{
using
PingP1
=
Number
<
0
>
;
using
PongP1
=
Number
<
1
>
;
// MFMA: Ping Reg
// DS_WRITE: To Ping LDS
// DS_READ: Pong LDS to Pong Reg
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
PongP1
{}),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
PongP1
{}));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
PongP1
{}),
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
PongP1
{}));
});
});
});
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
PingP1
{}));
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
PingP1
{}));
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
ik
)
=
a_thread_bufs
[
PingP1
{}][
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
ik
)
=
b_thread_bufs
[
PingP1
{}][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
TailScheduler
<
1
>
();
__builtin_amdgcn_sched_barrier
(
0
);
// -------------------------------------------------------------------------------------------
using
PingP2
=
Number
<
1
>
;
using
PongP2
=
Number
<
0
>
;
// MFMA: Pong Reg
// DS_WRITE: To Pong LDS
// DS_READ: Ping LDS to Ping Reg
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
PongP2
{}),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
PongP2
{}));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
PongP2
{}),
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
PongP2
{}));
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
ik
)
=
a_thread_bufs
[
PingP2
{}][
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
ik
)
=
b_thread_bufs
[
PingP2
{}][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
TailScheduler
<
2
>
();
__builtin_amdgcn_sched_barrier
(
0
);
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
ik
)
=
a_thread_bufs
[
PongP2
{}][
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
ik
)
=
b_thread_bufs
[
PongP2
{}][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
// 64 v_mfma
__builtin_amdgcn_sched_group_barrier
(
0x008
,
64
,
0
);
// MFMA
__builtin_amdgcn_sched_barrier
(
0
);
}
else
if
constexpr
(
TailNum
==
2
)
{
using
PingP1
=
Number
<
0
>
;
using
PongP1
=
Number
<
1
>
;
// MFMA: Ping Reg
// DS_WRITE: To Ping LDS
// DS_READ: Pong LDS to Pong Reg
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
PongP1
{}),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
PongP1
{}));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
PongP1
{}),
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
PongP1
{}));
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
ik
)
=
a_thread_bufs
[
PingP1
{}][
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
ik
)
=
b_thread_bufs
[
PingP1
{}][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
TailScheduler
<
2
>
();
__builtin_amdgcn_sched_barrier
(
0
);
// -------------------------------------------------------------------------------------------
using
PingP2
=
Number
<
1
>
;
// MFMA: Pong Reg
// DS_WRITE: To Pong LDS
// DS_READ: Ping LDS to Ping Reg
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
ik
)
=
a_thread_bufs
[
PingP2
{}][
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
ik
)
=
b_thread_bufs
[
PingP2
{}][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
// 64 v_mfma
__builtin_amdgcn_sched_group_barrier
(
0x008
,
64
,
0
);
// MFMA
__builtin_amdgcn_sched_barrier
(
0
);
}
}
protected:
// M1, N1 as double buffer index
// Read buffer + Compute buffer
// A[M0, M1, M2, KPack]
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPack
>
{}),
make_tuple
(
Number
<
KPack
>
{},
Number
<
KPack
*
MRepeat
*
KPack
>
{},
Number
<
MRepeat
*
KPack
>
{},
I1
));
// B[N0, N1, N2, KPack]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPack
>
{}),
make_tuple
(
Number
<
KPack
>
{},
Number
<
KPack
*
MRepeat
*
KPack
>
{},
Number
<
MRepeat
*
KPack
>
{},
I1
));
// C[M, N, NumRegXdlops]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
xdlops_gemm
.
GetRegSizePerXdlops
()));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPack
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPack
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
B_K1
,
B_K1
>
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
};
}
// namespace ck
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment