Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9de63596
Unverified
Commit
9de63596
authored
Mar 11, 2024
by
Illia Silin
Committed by
GitHub
Mar 11, 2024
Browse files
Merge pull request #48 from ROCm/merge_from_public
Merge from public
parents
e60c5aea
6ac1d6a2
Changes
303
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4169 additions
and
1012 deletions
+4169
-1012
example/62_convnd_activ/unary/convnd_fwd_xdl_leakyrelu_fp16.cpp
...e/62_convnd_activ/unary/convnd_fwd_xdl_leakyrelu_fp16.cpp
+11
-0
example/62_convnd_activ/unary/convnd_fwd_xdl_pow_fp16.cpp
example/62_convnd_activ/unary/convnd_fwd_xdl_pow_fp16.cpp
+11
-0
example/62_convnd_activ/unary/convnd_fwd_xdl_relu_fp16.cpp
example/62_convnd_activ/unary/convnd_fwd_xdl_relu_fp16.cpp
+11
-0
example/62_convnd_activ/unary/convnd_fwd_xdl_sigmoid_fp16.cpp
...ple/62_convnd_activ/unary/convnd_fwd_xdl_sigmoid_fp16.cpp
+11
-0
example/62_convnd_activ/unary/convnd_fwd_xdl_softrelu_fp16.cpp
...le/62_convnd_activ/unary/convnd_fwd_xdl_softrelu_fp16.cpp
+11
-0
example/62_convnd_activ/unary/convnd_fwd_xdl_tanh_fp16.cpp
example/62_convnd_activ/unary/convnd_fwd_xdl_tanh_fp16.cpp
+11
-0
example/64_fpAintB_gemm/CMakeLists.txt
example/64_fpAintB_gemm/CMakeLists.txt
+5
-0
example/64_fpAintB_gemm/common.hpp
example/64_fpAintB_gemm/common.hpp
+123
-0
example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp
example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp
+93
-0
example/64_fpAintB_gemm/run_gemm_example.inc
example/64_fpAintB_gemm/run_gemm_example.inc
+172
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+320
-641
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+42
-27
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp
...block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp
+223
-0
include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp
...e/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp
+46
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
...l/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
+198
-127
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
...evice_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
+1729
-0
include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp
...eration/gpu/device/impl/device_elementwise_scale_impl.hpp
+14
-1
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
...or_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
+714
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
.../gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
+187
-152
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
...n/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
+237
-64
No files found.
example/62_convnd_activ/unary/convnd_fwd_xdl_leakyrelu_fp16.cpp
0 → 100644
View file @
9de63596
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
LeakyRelu
;
using
DeviceGroupedConvNDActivInstance
=
DeviceGroupedConvNDFwdInstance
<
OutElementOp
>
;
#include "../run_convnd_activ_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_example
(
argc
,
argv
);
}
example/62_convnd_activ/unary/convnd_fwd_xdl_pow_fp16.cpp
0 → 100644
View file @
9de63596
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
Power
;
using
DeviceGroupedConvNDActivInstance
=
DeviceGroupedConvNDFwdInstance
<
OutElementOp
>
;
#include "../run_convnd_activ_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_example
(
argc
,
argv
);
}
example/62_convnd_activ/unary/convnd_fwd_xdl_relu_fp16.cpp
0 → 100644
View file @
9de63596
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
Relu
;
using
DeviceGroupedConvNDActivInstance
=
DeviceGroupedConvNDFwdInstance
<
OutElementOp
>
;
#include "../run_convnd_activ_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_example
(
argc
,
argv
);
}
example/62_convnd_activ/unary/convnd_fwd_xdl_sigmoid_fp16.cpp
0 → 100644
View file @
9de63596
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
Sigmoid
;
using
DeviceGroupedConvNDActivInstance
=
DeviceGroupedConvNDFwdInstance
<
OutElementOp
>
;
#include "../run_convnd_activ_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_example
(
argc
,
argv
);
}
example/62_convnd_activ/unary/convnd_fwd_xdl_softrelu_fp16.cpp
0 → 100644
View file @
9de63596
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
SoftRelu
;
using
DeviceGroupedConvNDActivInstance
=
DeviceGroupedConvNDFwdInstance
<
OutElementOp
>
;
#include "../run_convnd_activ_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_example
(
argc
,
argv
);
}
example/62_convnd_activ/unary/convnd_fwd_xdl_tanh_fp16.cpp
0 → 100644
View file @
9de63596
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
TanH
;
using
DeviceGroupedConvNDActivInstance
=
DeviceGroupedConvNDFwdInstance
<
OutElementOp
>
;
#include "../run_convnd_activ_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_example
(
argc
,
argv
);
}
example/64_fpAintB_gemm/CMakeLists.txt
0 → 100644
View file @
9de63596
if
(
GPU_TARGETS MATCHES
"gfx11"
)
add_custom_target
(
example_fpAintB_gemm_wmma
)
add_example_executable
(
example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp
)
add_dependencies
(
example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma
)
endif
()
example/64_fpAintB_gemm/common.hpp
0 → 100644
View file @
9de63596
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp"
struct
ProblemSize
final
{
ck
::
index_t
M
=
3840
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
StrideA
=
4096
;
ck
::
index_t
StrideB
=
4096
;
ck
::
index_t
StrideC
=
4096
;
};
struct
ExecutionConfig
final
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
};
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
template
<
typename
IntType
>
struct
UnsignedWeightPreprocessor
{
};
template
<
>
struct
UnsignedWeightPreprocessor
<
int8_t
>
{
using
UnsignedWeight
=
Tensor
<
uint8_t
>
;
using
SignedWeight
=
Tensor
<
int8_t
>
;
static
UnsignedWeight
convert
(
SignedWeight
const
&
Input
)
{
UnsignedWeight
Output
=
Input
.
template
CopyAsType
<
uint8_t
>();
auto
f_kn
=
[
&
](
auto
k
,
auto
n
)
{
const
uint8_t
adder
=
128
;
int8_t
v_signed_weight
;
uint8_t
v_unsigned_weight
;
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_signed_weight
,
Input
(
k
,
n
));
v_unsigned_weight
=
ck
::
type_convert
<
uint8_t
>
(
v_signed_weight
)
+
adder
;
Output
(
k
,
n
)
=
v_unsigned_weight
;
};
make_ParallelTensorFunctor
(
f_kn
,
Input
.
mDesc
.
GetLengths
()[
0
],
Input
.
mDesc
.
GetLengths
()[
1
])(
std
::
thread
::
hardware_concurrency
());
return
Output
;
}
UnsignedWeight
operator
()(
SignedWeight
const
&
Input
)
{
return
convert
(
Input
);
}
};
inline
bool
parse_cmd_args
(
int
argc
,
char
*
argv
[],
ProblemSize
&
problem_size
,
ExecutionConfig
&
config
)
{
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
10
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
problem_size
.
M
=
std
::
stoi
(
argv
[
4
]);
problem_size
.
N
=
std
::
stoi
(
argv
[
5
]);
problem_size
.
K
=
std
::
stoi
(
argv
[
6
]);
problem_size
.
StrideA
=
std
::
stoi
(
argv
[
7
]);
problem_size
.
StrideB
=
std
::
stoi
(
argv
[
8
]);
problem_size
.
StrideC
=
std
::
stoi
(
argv
[
9
]);
}
else
{
std
::
cerr
<<
"arg1: verification (0=no, 1=yes)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC"
<<
std
::
endl
;
return
false
;
}
return
true
;
}
example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp
0 → 100644
View file @
9de63596
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp"
// Implementation follows the paper:
// Kim, Young Jin, Rawn Henry, Raffy Fahim, and Hany Hassan Awadalla. “Who Says Elephants Can’t Run:
// Bringing Large Scale MoE Models into Cloud Scale Production.” arXiv, November 17, 2022.
// https://doi.org/10.48550/arXiv.2211.10017. Assume weight (Matrix B) is add preprocess to
// unsigned.
// The DeviceOp is CDataType = ADataType * Dequant(BDataType) * ScaleDataType
// The HostRef is CDataType = ADataType * Dequant(QuantDataType) * ScaleDataType
// TODO: Current implementation consume more VGPR than expected.
using
ADataType
=
ck
::
half_t
;
using
QuantDataType
=
int8_t
;
using
BDataType
=
uint8_t
;
using
ScaleDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
float
;
using
CDataType
=
ck
::
half_t
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceFpAintBGemm_Wmma_CShuffle
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
ScaleDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
// Prefetch stage
128
,
// BlockSize
64
,
// MPerBlock
128
,
// NPerBlock
64
,
// KPerBlock
8
,
// K1
16
,
// MPerWmma
16
,
// NPerWmma
2
,
// M-Repeat // M-PerWmma / M-Repeat = M-Wave
4
,
// N-Repeat // N-PerWmma / N-Repeat = N-Wave
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
// C shuffle (M Repeat) Per store
1
,
// C shuffle (N Repeat) Per store
S
<
1
,
32
,
1
,
4
>
,
8
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferencefpAintBGemm
<
ADataType
,
QuantDataType
,
ScaleDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
#include "run_gemm_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
example/64_fpAintB_gemm/run_gemm_example.inc
0 → 100644
View file @
9de63596
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool
run_gemm
(
const
ProblemSize
&
problem_size
,
const
ExecutionConfig
&
config
)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert
(
sizeof
(
ck
::
int4_t
)
==
sizeof
(
int8_t
));
#endif
using
namespace
ck
::
literals
;
auto
&
[
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
]
=
problem_size
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1_
uz
});
}
else
{
return
HostTensorDescriptor
({
row
,
col
},
{
1_
uz
,
stride
});
}
};
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
QuantDataType
>
quant_b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
// assume scale tensor is [1, n]
Tensor
<
ScaleDataType
>
scale_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
0
,
Row
{}));
switch
(
config
.
init_method
)
{
case
0
:
break
;
case
1
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
1.
f
,
1.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
QuantDataType
>
{
-
1.
f
,
1.
f
}(
quant_b_k_n
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ScaleDataType
>
{
-
1.
f
,
1.
f
}(
scale_k_n
);
break
;
case
2
:
ck
::
utils
::
FillUniformDistribution
<
ADataType
>
{
-
1.
f
,
1.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistribution
<
QuantDataType
>
{
-
1.
f
,
1.
f
}(
quant_b_k_n
);
ck
::
utils
::
FillUniformDistribution
<
ScaleDataType
>
{
-
1.
f
,
1.
f
}(
scale_k_n
);
break
;
default
:
ck
::
utils
::
FillUniformDistribution
<
ADataType
>
{
-
1.
f
,
1.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistribution
<
QuantDataType
>
{
-
1.
f
,
1.
f
}(
quant_b_k_n
);
ck
::
utils
::
FillUniformDistribution
<
ScaleDataType
>
{
-
1.
f
,
1.
f
}(
scale_k_n
);
}
UnsignedWeightPreprocessor
<
QuantDataType
>
preprocessor
;
Tensor
<
BDataType
>
b_k_n
=
preprocessor
(
quant_b_k_n
);
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"scale_k_n: "
<<
scale_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
#ifdef BUILD_INT4_EXAMPLE
DeviceMem
a_m_k_device_buf
(
sizeof
(
KernelADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
KernelBDataType
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
KernelCDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
const
Tensor
<
KernelADataType
>
a_m_k_converted
(
a_m_k
);
const
Tensor
<
KernelBDataType
>
b_k_n_converted
(
b_k_n
);
a_m_k_device_buf
.
ToDevice
(
a_m_k_converted
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n_converted
.
mData
.
data
());
#else
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
scale_k_n_device_buf
(
sizeof
(
ScaleDataType
)
*
scale_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
scale_k_n_device_buf
.
ToDevice
(
scale_k_n
.
mData
.
data
());
#endif
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
// do GEMM
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
#ifdef BUILD_INT4_EXAMPLE
static_cast
<
KernelADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelBDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelCDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
#else
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ScaleDataType
*>
(
scale_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
#endif
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
std
::
cerr
<<
gemm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
true
;
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
std
::
size_t
flop
=
2_
uz
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
if
(
config
.
do_verification
)
{
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
quant_b_k_n
,
scale_k_n
,
c_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
#ifdef BUILD_INT4_EXAMPLE
Tensor
<
CDataType
>
c_m_n_device_result_converted
(
c_m_n_host_result
.
mDesc
);
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result_converted
.
mData
.
data
());
c_m_n_device_result
=
c_m_n_device_result_converted
.
CopyAsType
<
CDataType
>
();
return
ck
::
utils
::
check_err
(
c_m_n_device_result_converted
,
c_m_n_host_result
);
#else
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
#endif
}
return
true
;
}
bool
run_gemm_example
(
int
argc
,
char
*
argv
[])
{
ProblemSize
problem_size
;
ExecutionConfig
config
;
return
!
parse_cmd_args
(
argc
,
argv
,
problem_size
,
config
)
||
run_gemm
(
problem_size
,
config
);
}
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
9de63596
...
...
@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#define CK_MNK_LOOP
...
...
@@ -16,25 +17,45 @@ template <index_t BlockSize,
typename
FloatA
,
typename
FloatB
,
typename
FloatAcc
,
typename
AK0MK1BlockDesc
,
typename
BK0NK1BlockDesc
,
typename
ABlockDesc
,
typename
BBlockDesc
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerWMMA
,
index_t
NPerWMMA
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
>
/* A: K0PerBlock x MPerBlock x K1
index_t
KPack
,
bool
AEnableLds
=
true
,
bool
BEnableLds
=
true
,
bool
TransposeC
=
false
>
/* Option: Read from LDS, big buffer hold all threads required data
* Source
* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* Destination
* C, non-transpose
* thread level: MRepeat x NRepeat x MAccVgprs
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*
* Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
* Source:
* A(if skip LDS): MRepeat x KPack
* B(if skip LDS): NRepeat x KPack
* Destination
* C, non-transpose
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
*/
struct
BlockwiseGemmWMMA
_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
struct
BlockwiseGemmWMMA
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
WmmaK
=
Number
<
16
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
...
...
@@ -42,18 +63,16 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
// Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
static
constexpr
index_t
WaveSize
=
32
;
static
constexpr
index_t
MPerBlock
=
AK0MK1BlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
NPerBlock
=
BK0NK1BlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
KPerBlock
=
BK0NK1BlockDesc
{}.
GetLength
(
I0
)
*
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
A_K0
=
AK0MK1BlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
B_K0
=
BK0NK1BlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
A_K1
=
AK0MK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
// When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// permutation
static
constexpr
index_t
A_KRow
=
AEnableLds
?
1
:
2
;
static
constexpr
index_t
B_KRow
=
BEnableLds
?
1
:
2
;
static
constexpr
index_t
A_K1
=
ABlockDesc
{}.
GetLength
(
I5
);
static
constexpr
index_t
B_K1
=
BBlockDesc
{}.
GetLength
(
I5
);
static
constexpr
auto
wmma_gemm
=
WmmaGemm
<
FloatA
,
FloatB
,
FloatAcc
,
MPerWMMA
,
NPerWMMA
,
KPack
>
{};
WmmaGemm
<
FloatA
,
FloatB
,
FloatAcc
,
MPerWMMA
,
NPerWMMA
,
KPack
,
TransposeC
>
{};
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWMMA
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWMMA
);
...
...
@@ -79,26 +98,39 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
// Default, Block buffer in LDS, thread level offset enabled
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
if
constexpr
(
AEnableLds
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
WMMA_a_idx
=
wmma_gemm
.
CalculateAThreadOriginDataIndex
();
const
auto
WMMA_a_idx
=
wmma_gemm
.
CalculateAThreadOriginDataIndex
();
// |KRepeat |MRepeat|MWave |MLane |KPack
return
make_tuple
(
0
,
0
,
waveId_m
,
WMMA_a_idx
,
0
);
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return
make_tuple
(
0
,
0
,
waveId_m
,
0
,
WMMA_a_idx
,
0
);
}
else
{
return
make_tuple
(
0
,
0
,
0
,
0
,
0
,
0
);
}
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
if
constexpr
(
BEnableLds
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
WMMA_b_idx
=
wmma_gemm
.
CalculateBThreadOriginDataIndex
();
const
auto
WMMA_b_idx
=
wmma_gemm
.
CalculateBThreadOriginDataIndex
();
// |KRepeat |NRepeat|Nwave |NLane |KPack
return
make_tuple
(
0
,
0
,
waveId_n
,
WMMA_b_idx
,
0
);
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return
make_tuple
(
0
,
0
,
waveId_n
,
0
,
WMMA_b_idx
,
0
);
}
else
{
return
make_tuple
(
0
,
0
,
0
,
0
,
0
,
0
);
}
}
template
<
index_t
m0
,
index_t
n0
>
...
...
@@ -129,10 +161,26 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
__host__
__device__
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
()
template
<
index_t
m0
,
index_t
n0
>
__device__
static
auto
CalculateCThreadOriginDataIndex7D
(
Number
<
m0
>
,
Number
<
n0
>
)
{
static_assert
(
AK0MK1BlockDesc
::
IsKnownAtCompileTime
()
&&
BK0NK1BlockDesc
::
IsKnownAtCompileTime
(),
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
wmma_gemm
.
GetBeginOfThreadBlk3D
();
return
make_tuple
(
Number
<
m0
>
{},
waveId_m
,
blk_idx
[
I0
],
Number
<
n0
>
{},
waveId_n
,
blk_idx
[
I1
],
blk_idx
[
I2
]);
}
using
Tuple6
=
decltype
(
CalculateAThreadOriginDataIndex
());
__host__
__device__
BlockwiseGemmWMMA
(
Tuple6
a_origin
=
CalculateAThreadOriginDataIndex
(),
Tuple6
b_origin
=
CalculateBThreadOriginDataIndex
())
:
a_thread_copy_
(
a_origin
),
b_thread_copy_
(
b_origin
)
{
static_assert
(
ABlockDesc
::
IsKnownAtCompileTime
()
&&
BBlockDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
ThisThreadBlock
::
GetNumOfThread
()
==
MWaves
*
NWaves
*
WaveSize
,
...
...
@@ -143,32 +191,68 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
"wrong!"
);
}
//
Thread level, register decriptor. Vector-write
//
transposed WMMA output C' = B' * A'
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_N
ThreadPer
SubGroup_
M
AccVgprs
()
GetCThreadDescriptor_MRepeat_MWave_M
ThreadPer
SubGroup_NRepeat_NWave_NSubGroup_
N
AccVgprs
()
{
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
constexpr
auto
MSubGroup
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I0
];
constexpr
auto
NThreadPerSubGroup
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I1
];
constexpr
auto
MAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
constexpr
auto
NAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
return
make_naive_tensor_descriptor_packed
(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
NRepeat
>
{},
I1
,
I1
,
NAccVgprs
));
}
// Thread level, register decriptor. Vector-write
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
{
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
constexpr
auto
MAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
constexpr
auto
AccStride
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor
(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
MSubGroup
,
Number
<
NRepeat
>
{},
I1
,
NThreadPerSubGroup
,
MAccVgprs
));
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
NRepeat
>
{},
I1
,
I1
,
MAccVgprs
),
make_tuple
(
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
AccStride
));
}
// Provide dimension size
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerWMMA
),
MWaves
,
MPerWMMA
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerWMMA
),
NWaves
,
NPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
);
}
// transposed WMMA output C' = B' * A'
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs
()
{
constexpr
auto
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
...
...
@@ -179,37 +263,31 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
Number
<
NPerWMMA
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_N
ThreadPer
SubGroup_
M
AccVgprs
(
.
MakeCDesc_MBlockxRepeat_MWave_M
ThreadPer
SubGroup_NBlockxRepeat_NWave_NSubGroup_
N
AccVgprs
(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
);
}
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_K0_M0_M1_M2_K1
()
// Provide dimension size
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
{
return
transform_tensor_descriptor
(
AK0MK1BlockDesc
{},
make_tuple
(
make_pass_through_transform
(
Number
<
A_K0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{})),
make_pass_through_transform
(
Number
<
A_K1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
}
constexpr
auto
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{},
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{}));
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor_K0_N0_N1_N2_K1
()
{
return
transform_tensor_descriptor
(
BK0NK1BlockDesc
{},
make_tuple
(
make_pass_through_transform
(
Number
<
B_K0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{})),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
);
}
// Describe how data allocated in thread copy src buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static
constexpr
auto
a_block_desc_k0_m0_m1_m2_k1
=
MakeABlockDescriptor_K0_M0_M1_M2_K1
()
;
static
constexpr
auto
b_block_desc_k0_n0_n1_n2_k1
=
MakeBBlockDescriptor_K0_N0_N1_N2_K1
()
;
static
constexpr
ABlockDesc
a_block_desc_k0_m0_m1_m2_k1
;
static
constexpr
BBlockDesc
b_block_desc_k0_n0_n1_n2_k1
;
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
...
...
@@ -224,36 +302,48 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
// basic intrinsic to determine loopover direction
if
constexpr
(
MRepeat
<
NRepeat
)
{
static_for
<
0
,
KPerBlock
/
WmmaK
,
1
>
{}(
static_for
<
0
,
KPerBlock
/
KPack
,
1
>
{}(
[
&
](
auto
k
)
{
// k=0,1,2 instead of k=0,kpack*1, ...
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
A_K1
>
{},
m0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
),
a_thread_buf
);
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
A_K1
/
A_KRow
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
B_K
1
>
{},
n0
,
I0
,
I0
,
I0
),
make_tuple
(
Number
<
k
*
KPack
/
B_K1
/
B_K
Row
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
vector_type
<
FloatA
,
WmmaK
>
a_thread_vec
;
vector_type
<
FloatB
,
WmmaK
>
b_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
i
)
{
vector_type
<
FloatA
,
KPack
>
a_thread_vec
;
vector_type
<
FloatB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
A_K1
,
m0
,
0
,
0
,
i
%
A_K1
))
>
{}];
make_tuple
(
i
/
A_K1
/
A_KRow
,
m0
,
0
,
(
i
/
A_K1
)
%
A_KRow
,
0
,
i
%
A_K1
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
,
n0
,
0
,
0
,
i
%
B_K1
))
>
{}];
make_tuple
(
i
/
B_K1
/
B_KRow
,
n0
,
0
,
(
i
/
B_K1
)
%
B_KRow
,
0
,
i
%
B_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
>::
type
;
...
...
@@ -263,8 +353,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>()
(
Number
<
0
>{})
,
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>()
(
Number
<
0
>
{})
,
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
...
...
@@ -272,583 +362,172 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
}
else
{
static_for
<
0
,
KPerBlock
/
WmmaK
,
1
>
{}(
[
&
](
auto
k
)
{
// k=0,1,2 instead of k=0,kpack*1, ...
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
KPerBlock
/
KPack
,
1
>
{}([
&
](
auto
k
)
{
// k=0,1,2 instead of
// k=0,kpack*1, ..
// read B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
B_K1
>
{},
n0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
),
b_thread_buf
);
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
A_K1
>
{},
m0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
),
a_thread_buf
);
vector_type
<
FloatA
,
WmmaK
>
a_thread_vec
;
vector_type
<
FloatB
,
WmmaK
>
b_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
A_K1
,
m0
,
0
,
0
,
i
%
A_K1
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
,
n0
,
0
,
0
,
i
%
B_K1
))
>
{}];
});
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
B_K1
/
B_KRow
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
A_K1
/
A_KRow
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
vector_type
<
FloatA
,
KPack
>
a_thread_vec
;
vector_type
<
FloatB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
/
B_KRow
,
n0
,
0
,
(
i
/
B_K1
)
%
B_KRow
,
0
,
i
%
B_K1
))
>
{}];
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
A_K1
/
A_KRow
,
m0
,
0
,
(
i
/
A_K1
)
%
A_KRow
,
0
,
i
%
A_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
>::
type
;
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>()(
Number
<
0
>{}),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>()(
Number
<
0
>
{}),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
}
}
protected:
// A[K0, M0, M1, M2, K1]
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
WmmaK
/
A_K1
>
{},
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
A_K1
>
{}));
// B[K0, N0, N1, N2, K1]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
WmmaK
/
B_K1
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
Number
<
B_K1
>
{}));
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
A_K1
/
A_KRow
>
{},
Number
<
MRepeat
>
{},
I1
,
Number
<
A_KRow
>
{},
I1
,
Number
<
A_K1
>
{}),
make_tuple
(
Number
<
A_K1
*
A_KRow
>
{},
Number
<
KPack
>
{},
Number
<
A_K1
*
A_KRow
>
{},
Number
<
A_K1
>
{},
Number
<
A_K1
>
{},
Number
<
1
>
{}));
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
B_K1
/
B_KRow
>
{},
Number
<
NRepeat
>
{},
I1
,
Number
<
B_KRow
>
{},
I1
,
Number
<
B_K1
>
{}),
make_tuple
(
Number
<
B_K1
*
B_KRow
>
{},
Number
<
KPack
>
{},
Number
<
B_K1
*
B_KRow
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
1
>
{}));
// C[M, N, NumRegWMMA]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
wmma_gemm
.
GetRegSizePerWmma
()));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_thread_desc_
),
Sequence
<
WmmaK
/
A_K1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
4
,
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_thread_desc_
),
Sequence
<
WmmaK
/
B_K1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
4
,
B_K1
,
B_K1
>
;
AThreadCopy
a_thread_copy_
{
CalculateAThreadOriginDataIndex
()};
BThreadCopy
b_thread_copy_
{
CalculateBThreadOriginDataIndex
()};
};
// block wise level pipe designed for inline asm
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatB
,
typename
FloatAcc
,
typename
AK0MK1BlockDesc
,
typename
BK0NK1BlockDesc
,
index_t
MPerWMMA
,
index_t
NPerWMMA
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
>
/* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*/
struct
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
WmmaK
=
Number
<
16
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
// Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
static
constexpr
index_t
WaveSize
=
32
;
static
constexpr
index_t
MPerBlock
=
AK0MK1BlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
NPerBlock
=
BK0NK1BlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
KPerBlock
=
BK0NK1BlockDesc
{}.
GetLength
(
I0
)
*
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
A_K0
=
AK0MK1BlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
B_K0
=
BK0NK1BlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
A_K1
=
AK0MK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
wmma_gemm
=
WmmaGemm
<
FloatA
,
FloatB
,
FloatAcc
,
MPerWMMA
,
NPerWMMA
,
KPack
>
{};
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWMMA
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWMMA
);
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc
,
MRepeat
*
NRepeat
,
wmma_gemm
.
GetRegSizePerWmma
(),
true
>
c_thread_buf_
;
__host__
__device__
constexpr
auto
&
GetCThreadBuffer
()
{
return
c_thread_buf_
;
}
__device__
static
auto
GetWaveIdx
()
{
const
index_t
thread_id
=
ThisThreadBlock
::
GetThreadId
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
MWaves
,
NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
WMMA_a_idx
=
wmma_gemm
.
CalculateAThreadOriginDataIndex
();
// |KRepeat |MRepeat|MWave |MLane |KPack
return
make_tuple
(
0
,
0
,
waveId_m
,
WMMA_a_idx
,
0
);
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
WMMA_b_idx
=
wmma_gemm
.
CalculateBThreadOriginDataIndex
();
// |KRepeat |NRepeat|Nwave |NLane |KPack
return
make_tuple
(
0
,
0
,
waveId_n
,
WMMA_b_idx
,
0
);
}
template
<
index_t
m0
,
index_t
n0
>
__device__
static
auto
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
wmma_gemm
.
GetBeginOfThreadBlk
();
constexpr
auto
mrepeat_mwave_mperWMMA_to_m_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
MPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
constexpr
auto
nrepeat_nwave_nperWMMA_to_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
NPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
const
index_t
c_thread_m
=
mrepeat_mwave_mperWMMA_to_m_adaptor
.
CalculateBottomIndex
(
make_tuple
(
m0
,
waveId_m
,
blk_idx
[
I0
]))[
I0
];
const
index_t
c_thread_n
=
nrepeat_nwave_nperWMMA_to_n_adaptor
.
CalculateBottomIndex
(
make_tuple
(
n0
,
waveId_n
,
blk_idx
[
I1
]))[
I0
];
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
__host__
__device__
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
()
{
static_assert
(
AK0MK1BlockDesc
::
IsKnownAtCompileTime
()
&&
BK0NK1BlockDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
ThisThreadBlock
::
GetNumOfThread
()
==
MWaves
*
NWaves
*
WaveSize
,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
MPerBlock
%
(
MPerWMMA
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerWMMA
*
NRepeat
)
==
0
,
"wrong!"
);
}
// Thread level, register decriptor. Vector-write
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
{
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
constexpr
auto
MSubGroup
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I0
];
constexpr
auto
NThreadPerSubGroup
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I1
];
constexpr
auto
MAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
return
make_naive_tensor_descriptor_packed
(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
MSubGroup
,
Number
<
NRepeat
>
{},
I1
,
NThreadPerSubGroup
,
MAccVgprs
));
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
template
<
bool
EnableLds
>
struct
AThreadCopySelector
;
const
auto
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerWMMA
),
MWaves
,
MPerWMMA
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerWMMA
),
NWaves
,
NPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
);
}
// Provide dimension size
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
template
<
>
struct
AThreadCopySelector
<
true
>
{
constexpr
auto
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{}
,
Number
<
MWaves
>
{}
,
Number
<
MPerWMMA
>
{}
,
Number
<
NRepeat
>
{}
,
Number
<
NWaves
>
{}
,
Number
<
NPerWMMA
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
)
;
}
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_K0_M0_M1_M2_K1
()
using
type
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
)
,
decltype
(
a_thread_desc_
)
,
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
A_KRow
,
1
,
A_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
A_K1
,
A_K1
>
;
}
;
template
<
>
struct
AThreadCopySelector
<
false
>
{
return
transform_tensor_descriptor
(
AK0MK1BlockDesc
{},
make_tuple
(
make_pass_through_transform
(
Number
<
A_K0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{})),
make_pass_through_transform
(
Number
<
A_K1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
}
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor_K0_N0_N1_N2_K1
()
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_thread_desc_
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
A_K1
,
0x76543210
,
0xfedcba98
,
TransposeC
?
false
:
true
>
;
};
template
<
bool
EnableLds
>
struct
BThreadCopySelector
;
template
<
>
struct
BThreadCopySelector
<
true
>
{
return
transform_tensor_descriptor
(
BK0NK1BlockDesc
{},
make_tuple
(
make_pass_through_transform
(
Number
<
B_K0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{})),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
}
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static
constexpr
auto
a_block_desc_k0_m0_m1_m2_k1
=
MakeABlockDescriptor_K0_M0_M1_M2_K1
();
static
constexpr
auto
b_block_desc_k0_n0_n1_n2_k1
=
MakeBBlockDescriptor_K0_N0_N1_N2_K1
();
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
using
type
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_thread_desc_
),
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
B_KRow
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
B_K1
,
B_K1
>
;
};
template
<
>
struct
BThreadCopySelector
<
false
>
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
constexpr
auto
RepeatDiff
=
MRepeat
-
NRepeat
;
// Read all Mrepeat, Nrepeat
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
iN
)
{
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
I0
,
Number
<
iN
>
{},
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
Number
<
iN
>
{},
I0
,
I0
,
I0
),
b_thread_buf
);
});
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
I0
,
Number
<
iM
>
{},
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
Number
<
iM
>
{},
I0
,
I0
,
I0
),
a_thread_buf
);
});
// Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat
static_for
<
0
,
RepeatDiff
,
1
>
{}([
&
](
auto
iCut
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
iN
)
{
vector_type
<
FloatA
,
WmmaK
>
a_thread_vec
;
vector_type
<
FloatB
,
WmmaK
>
b_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
iK
)
{
a_thread_vec
.
template
AsType
<
FloatA
>()(
iK
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
iK
/
A_K1
,
iCut
,
0
,
0
,
iK
%
A_K1
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatB
>()(
iK
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
iK
/
B_K1
,
iN
,
0
,
0
,
iK
%
B_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
iCut
,
iN
,
0
));
// s_nop();
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>()(
Number
<
0
>{}),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>()(
Number
<
0
>
{}),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
// s_nop();
});
if
constexpr
(
KPerBlock
>
WmmaK
)
{
// Read Consumed Next inner loop A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
WmmaK
/
A_K1
>
{},
Number
<
iCut
>
{},
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
Number
<
iCut
>
{},
I0
,
I0
,
I0
),
a_thread_buf
);
}
});
static_for
<
WmmaK
,
KPerBlock
,
WmmaK
>
{}([
&
](
auto
iWmmaK
)
{
// Stage 2: Run FIFO fashion loopover in Square
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
WmmaInnerloop
)
{
// Row Repeatation
static_for
<
WmmaInnerloop
,
NRepeat
,
1
>
{}([
&
](
auto
iN
)
{
vector_type
<
FloatA
,
WmmaK
>
a_thread_vec
;
vector_type
<
FloatB
,
WmmaK
>
b_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
iK
)
{
a_thread_vec
.
template
AsType
<
FloatA
>()(
iK
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
iK
/
A_K1
,
WmmaInnerloop
+
RepeatDiff
,
0
,
0
,
iK
%
A_K1
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatB
>()(
iK
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
iK
/
B_K1
,
iN
,
0
,
0
,
iK
%
B_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
WmmaInnerloop
+
RepeatDiff
,
iN
,
0
));
// s_nop();
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>()(
Number
<
0
>{}),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>()(
Number
<
0
>
{}),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
// s_nop();
});
// Read Consumed Next inner loop A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
iWmmaK
/
A_K1
>
{},
Number
<
WmmaInnerloop
+
RepeatDiff
>
{},
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
Number
<
WmmaInnerloop
+
RepeatDiff
>
{},
I0
,
I0
,
I0
),
a_thread_buf
);
// Col Repeatation
static_for
<
WmmaInnerloop
+
1
+
RepeatDiff
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
vector_type
<
FloatA
,
WmmaK
>
a_thread_vec
;
vector_type
<
FloatB
,
WmmaK
>
b_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
iK
)
{
a_thread_vec
.
template
AsType
<
FloatA
>()(
iK
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
iK
/
A_K1
,
iM
,
0
,
0
,
iK
%
A_K1
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatB
>()(
iK
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
iK
/
B_K1
,
WmmaInnerloop
,
0
,
0
,
iK
%
B_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
iM
,
WmmaInnerloop
,
0
));
// s_nop();
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>()(
Number
<
0
>{}),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>()(
Number
<
0
>
{}),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
// s_nop();
});
// Read Consumed Next inner loop B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
iWmmaK
/
B_K1
>
{},
Number
<
WmmaInnerloop
>
{},
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
Number
<
WmmaInnerloop
>
{},
I0
,
I0
,
I0
),
b_thread_buf
);
});
// Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat
static_for
<
0
,
RepeatDiff
,
1
>
{}([
&
](
auto
iCut
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
iN
)
{
vector_type
<
FloatA
,
WmmaK
>
a_thread_vec
;
vector_type
<
FloatB
,
WmmaK
>
b_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
iK
)
{
a_thread_vec
.
template
AsType
<
FloatA
>()(
iK
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
iK
/
A_K1
,
iCut
,
0
,
0
,
iK
%
A_K1
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatB
>()(
iK
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
iK
/
B_K1
,
iN
,
0
,
0
,
iK
%
B_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
iCut
,
iN
,
0
));
// s_nop();
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>()(
Number
<
0
>{}),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>()(
Number
<
0
>
{}),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
// s_nop();
});
if
constexpr
(
KPerBlock
>
WmmaK
)
{
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
(
iWmmaK
+
WmmaK
)
/
A_K1
>
{},
Number
<
iCut
>
{},
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
Number
<
iCut
>
{},
I0
,
I0
,
I0
),
a_thread_buf
);
}
});
});
// Stage 2: Run FIFO fashion loopover in Square
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
WmmaInnerloop
)
{
// Row Repeatation
static_for
<
WmmaInnerloop
,
NRepeat
,
1
>
{}([
&
](
auto
iN
)
{
vector_type
<
FloatA
,
WmmaK
>
a_thread_vec
;
vector_type
<
FloatB
,
WmmaK
>
b_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
iK
)
{
a_thread_vec
.
template
AsType
<
FloatA
>()(
iK
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
iK
/
A_K1
,
WmmaInnerloop
+
RepeatDiff
,
0
,
0
,
iK
%
A_K1
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatB
>()(
iK
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
iK
/
B_K1
,
iN
,
0
,
0
,
iK
%
B_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
WmmaInnerloop
+
RepeatDiff
,
iN
,
0
));
// s_nop();
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>()(
Number
<
0
>{}),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>()(
Number
<
0
>
{}),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
// s_nop();
});
// Col Repeatation
static_for
<
WmmaInnerloop
+
1
+
RepeatDiff
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
vector_type
<
FloatA
,
WmmaK
>
a_thread_vec
;
vector_type
<
FloatB
,
WmmaK
>
b_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
iK
)
{
a_thread_vec
.
template
AsType
<
FloatA
>()(
iK
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
iK
/
A_K1
,
iM
,
0
,
0
,
iK
%
A_K1
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatB
>()(
iK
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
iK
/
B_K1
,
WmmaInnerloop
,
0
,
0
,
iK
%
B_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
iM
,
WmmaInnerloop
,
0
));
// s_nop();
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>()(
Number
<
0
>{}),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>()(
Number
<
0
>
{}),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
// s_nop();
});
});
}
protected:
// A[M0, M1, M2, K0 = WmmaK]
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
WmmaK
/
A_K1
>
{},
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
A_K1
>
{}));
// B[N0, N1, N2, K0 = WmmaK]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
WmmaK
/
B_K1
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
Number
<
B_K1
>
{}));
// C[M, N, NumRegWMMA]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
wmma_gemm
.
GetRegSizePerWmma
()));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_thread_desc_
),
Sequence
<
WmmaK
/
A_K1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
4
,
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_thread_desc_
),
Sequence
<
WmmaK
/
B_K1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
4
,
B_K1
,
B_K1
>
;
AThreadCopy
a_thread_copy_
{
CalculateAThreadOriginDataIndex
()};
BThreadCopy
b_thread_copy_
{
CalculateBThreadOriginDataIndex
()};
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_thread_desc_
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
B_K1
,
0x76543210
,
0xfedcba98
,
TransposeC
?
true
:
false
>
;
};
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
BThreadCopySelector
<
BEnableLds
>::
type
b_thread_copy_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
9de63596
...
...
@@ -37,7 +37,9 @@ template <index_t BlockSize,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
>
index_t
KPack
,
typename
ComputeTypeA
=
FloatA
,
typename
ComputeTypeB
=
FloatB
>
struct
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -59,7 +61,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static
constexpr
index_t
A_K1
=
AK0MK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatA
,
MPerXDL
,
NPerXDL
,
KPack
,
FloatB
>
{};
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
ComputeTypeA
,
MPerXDL
,
NPerXDL
,
KPack
,
ComputeTypeB
>
{};
static
constexpr
index_t
KPerThread
=
KPerBlock
/
xdlops_gemm
.
K0PerXdlops
;
...
...
@@ -295,9 +298,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
A
>
(
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeType
A
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
B
>
(
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeType
B
>
(
b_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
...
...
@@ -319,20 +322,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
b_thread_buf
);
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
vector_type
<
Float
A
,
KPack
>
a_thread_vec
;
vector_type
<
Float
B
,
KPack
>
b_thread_vec
;
vector_type
<
ComputeType
A
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeType
B
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
Float
A
>()(
i
)
=
a_thread_buf
a_thread_vec
.
template
AsType
<
ComputeType
A
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
Float
B
>()(
i
)
=
b_thread_buf
b_thread_vec
.
template
AsType
<
ComputeType
B
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
});
using
mfma_input_type_a
=
typename
vector_type
<
Float
A
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeType
A
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
using
mfma_input_type_b
=
typename
vector_type
<
Float
B
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeType
B
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
...
...
@@ -360,7 +363,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
xdlops_gemm
.
GetRegSizePerXdlops
()));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
Float
A
,
ComputeType
A
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
...
...
@@ -370,7 +373,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
Float
B
,
ComputeType
B
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
...
...
@@ -398,6 +401,8 @@ template <index_t BlockSize,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
,
typename
ComputeTypeA
=
FloatA
,
typename
ComputeTypeB
=
FloatB
,
index_t
NumMacClusters
=
CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
>
struct
BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
:
public
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
...
...
@@ -410,7 +415,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
KPack
,
ComputeTypeA
,
ComputeTypeB
>
{
using
Base
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatA
,
...
...
@@ -422,7 +429,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
KPack
,
ComputeTypeA
,
ComputeTypeB
>
;
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
using
Base
::
a_block_desc_m0_m1_m2_k
;
...
...
@@ -446,9 +455,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
A
>
(
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeType
A
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
B
>
(
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeType
B
>
(
b_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
KPerThread
,
KPerInnerLoop
>
{}([
&
](
auto
k
)
{
...
...
@@ -485,22 +494,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
Float
A
,
KPack
>
a_thread_vec
;
vector_type
<
Float
B
,
KPack
>
b_thread_vec
;
vector_type
<
ComputeType
A
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeType
B
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
Float
A
>()(
i
)
=
a_thread_vec
.
template
AsType
<
ComputeType
A
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
0
,
0
,
k_
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
Float
B
>()(
i
)
=
b_thread_vec
.
template
AsType
<
ComputeType
B
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
0
,
0
,
k_
+
i
))
>
{}];
});
using
mfma_input_type_a
=
typename
vector_type
<
Float
A
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeType
A
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
using
mfma_input_type_b
=
typename
vector_type
<
Float
B
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeType
B
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
...
...
@@ -550,7 +559,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
I1
,
Number
<
KPerInnerLoop
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
Float
A
,
ComputeType
A
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
...
...
@@ -560,7 +569,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
Float
B
,
ComputeType
B
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
...
...
@@ -586,7 +595,9 @@ template <index_t BlockSize,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
,
LoopScheduler
LoopSched
>
LoopScheduler
LoopSched
,
typename
ComputeTypeA
=
FloatA
,
typename
ComputeTypeB
=
FloatB
>
constexpr
auto
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
()
{
if
constexpr
(
LoopSched
==
LoopScheduler
::
Default
)
...
...
@@ -601,7 +612,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
KPack
,
ComputeTypeA
,
ComputeTypeB
>
{};
}
else
if
constexpr
(
LoopSched
==
LoopScheduler
::
Interwave
)
{
...
...
@@ -615,7 +628,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
KPack
,
ComputeTypeA
,
ComputeTypeB
>
{};
}
};
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp
0 → 100644
View file @
9de63596
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp"
namespace
ck
{
/**
* @brief Blockwise data transfer with dequantization
*
* RunRead would load low-precision data and scale data.
* RunWrite would process dequantization process.
* Assume Scale is identical along K-dimension
*
* This version does following things to avoid scratch memory issue
* 1. Use StaticallyIndexedArray instead of C array for thread buffer
* 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
* 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
*
*/
template
<
typename
ThreadGroup
,
typename
SrcElementwiseOperation
,
typename
ScaleElementwiseOperation
,
typename
DstElementwiseOperation
,
InMemoryDataOperationEnum
DstInMemOp
,
typename
BlockSliceLengths
,
typename
BlockScaleSliceLengths
,
typename
ThreadClusterLengths
,
typename
ThreadClusterArrangeOrder
,
typename
SrcData
,
typename
ScaleData
,
typename
DstData
,
typename
SrcDesc
,
typename
ScaleDesc
,
typename
DstDesc
,
typename
SrcDimAccessOrder
,
typename
DstDimAccessOrder
,
index_t
SrcVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
ScaleScalarPerVector
,
index_t
DstScalarPerVector
,
index_t
SrcScalarStrideInVector
,
index_t
ScaleScalarStrideInVector
,
index_t
DstScalarStrideInVector
,
bool
ThreadTransferSrcResetCoordinateAfterRun
,
bool
ThreadTransferDstResetCoordinateAfterRun
,
index_t
NumThreadScratch
=
1
>
struct
ThreadGroupTensorSliceTransfer_v4r1_dequant
{
static
constexpr
index_t
nDim
=
remove_reference_t
<
SrcDesc
>::
GetNumOfDimension
();
static
constexpr
auto
thread_slice_lengths
=
BlockSliceLengths
{}
/
ThreadClusterLengths
{};
static
constexpr
auto
scale_thread_slice_lengths
=
BlockScaleSliceLengths
{}
/
ThreadClusterLengths
{};
using
Index
=
MultiIndex
<
nDim
>
;
__device__
constexpr
ThreadGroupTensorSliceTransfer_v4r1_dequant
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
,
const
SrcElementwiseOperation
&
src_element_op
,
const
ScaleDesc
&
scale_desc
,
const
Index
&
scale_block_slice_origin
,
const
ScaleElementwiseOperation
&
scale_element_op
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_block_slice_origin
,
const
DstElementwiseOperation
&
dst_element_op
)
:
threadwise_transfer_
(
src_desc
,
make_zero_multi_index
<
nDim
>
(),
src_element_op
,
scale_desc
,
make_zero_multi_index
<
nDim
>
(),
scale_element_op
,
dst_desc
,
make_zero_multi_index
<
nDim
>
(),
dst_element_op
)
{
static_assert
(
nDim
==
remove_cvref_t
<
SrcDesc
>::
GetNumOfDimension
()
&&
nDim
==
remove_cvref_t
<
ScaleDesc
>::
GetNumOfDimension
()
&&
nDim
==
remove_cvref_t
<
DstDesc
>::
GetNumOfDimension
()
&&
nDim
==
ThreadClusterLengths
::
Size
()
&&
nDim
==
ThreadClusterArrangeOrder
::
Size
()
&&
nDim
==
SrcDimAccessOrder
::
Size
()
&&
nDim
==
DstDimAccessOrder
::
Size
(),
"wrong! nDim not consistent"
);
static_assert
(
is_same
<
BlockSliceLengths
,
decltype
(
thread_slice_lengths
*
ThreadClusterLengths
{})
>
{}
&&
is_same
<
BlockScaleSliceLengths
,
decltype
(
scale_thread_slice_lengths
*
ThreadClusterLengths
{})
>
{},
"wrong! threads should be mapped to cover entire slicing window"
);
static_assert
(
ThreadGroup
::
GetNumOfThread
()
>=
thread_cluster_desc_
.
GetElementSize
(),
"wrong! ThreadGroup::GetNumOfThread() too small"
);
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
ThreadGroup
::
GetThreadId
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
threadwise_transfer_
.
SetSrcSliceOrigin
(
src_desc
,
src_block_slice_origin
+
thread_data_idx_begin
);
threadwise_transfer_
.
SetScaleSliceOrigin
(
scale_desc
,
scale_block_slice_origin
+
thread_data_idx_begin
);
threadwise_transfer_
.
SetDstSliceOrigin
(
dst_desc
,
dst_block_slice_origin
+
thread_data_idx_begin
);
}
}
template
<
typename
SrcBuffer
,
index_t
ThreadScratchId
=
0
>
__device__
void
RunRead
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
Number
<
ThreadScratchId
>
thread_scratch_id
=
Number
<
ThreadScratchId
>
{})
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunRead
(
src_desc
,
src_buf
,
thread_scratch_id
);
}
}
// With the assumption, scale scratch is always one
template
<
typename
ScaleBuffer
>
__device__
void
RunScaleRead
(
const
ScaleDesc
&
scale_desc
,
const
ScaleBuffer
&
scale_buf
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunScaleRead
(
scale_desc
,
scale_buf
);
}
}
template
<
typename
DstBuffer
,
index_t
ThreadScratchId
=
0
>
__device__
void
RunWrite
(
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
,
Number
<
ThreadScratchId
>
thread_scratch_id
=
Number
<
ThreadScratchId
>
{})
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunWrite
(
dst_desc
,
dst_buf
,
thread_scratch_id
);
}
}
// We don't prefer use this API directly
/*
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id)
{
RunRead(src_desc, src_buf, thread_scratch_id);
RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
*/
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveSrcSliceWindow
(
src_desc
,
step
);
}
}
__device__
void
MoveDstSliceWindow
(
const
DstDesc
&
dst_desc
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveDstSliceWindow
(
dst_desc
,
step
);
}
}
// With the assumption, scale buffer don't need move slice window method
private:
static
constexpr
auto
thread_cluster_desc_
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
ThreadwiseTensorSliceTransfer_v3r1_dequant
<
decltype
(
thread_slice_lengths
),
decltype
(
scale_thread_slice_lengths
),
SrcElementwiseOperation
,
ScaleElementwiseOperation
,
DstElementwiseOperation
,
DstInMemOp
,
SrcData
,
ScaleData
,
DstData
,
SrcDesc
,
ScaleDesc
,
DstDesc
,
SrcDimAccessOrder
,
DstDimAccessOrder
,
SrcVectorDim
,
DstVectorDim
,
SrcScalarPerVector
,
ScaleScalarPerVector
,
DstScalarPerVector
,
SrcScalarStrideInVector
,
ScaleScalarStrideInVector
,
DstScalarStrideInVector
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferDstResetCoordinateAfterRun
,
NumThreadScratch
>
;
ThreadwiseTransfer
threadwise_transfer_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp
0 → 100644
View file @
9de63596
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Dequantization of input tensor could not be decoupled from gridwisegemm pipeline
// As input tensor thread buffer declared inside blockwise-gemm pipeline.
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGemm_dequantB
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_scale
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
View file @
9de63596
...
...
@@ -62,10 +62,10 @@ template <index_t NumDimG,
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
...
...
@@ -73,13 +73,14 @@ template <index_t NumDimG,
TensorSpecialization
ASpec
,
TensorSpecialization
BSpec
,
TensorSpecialization
DESpec
,
ck
::
index_t
NumPrefetch
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K
0
PerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
MPerW
MMA
,
ck
::
index_t
NPerW
MMA
,
ck
::
index_t
MPerW
mma
,
ck
::
index_t
NPerW
mma
,
ck
::
index_t
MRepeat
,
ck
::
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
...
...
@@ -100,7 +101,6 @@ template <index_t NumDimG,
index_t
CShuffleNRepeatPerShuffle
,
typename
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
ck
::
index_t
NumPrefetch
=
1
,
ck
::
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
v1
>
struct
DeviceBatchedContractionMultipleD_Wmma_CShuffle
...
...
@@ -123,15 +123,32 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
// K1 = Max Vector Access Pixels
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
static
constexpr
auto
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
static
constexpr
auto
WmmaK
=
K1
==
16
?
32
:
16
;
static
constexpr
auto
AEnableLds_auto
=
NWaves
==
1
?
false
:
true
;
static
constexpr
auto
BEnableLds_auto
=
MWaves
==
1
?
false
:
true
;
// If true, LDS is used unconditionally
static
constexpr
auto
AEnableLds_manu
=
false
;
static
constexpr
auto
BEnableLds_manu
=
false
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
K
0
PerBlock
*
K1
};
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
// Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
static
auto
MakeAGridDescriptor
_M_K
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
static
auto
MakeAGridDescriptor
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
{
assert
(
a_gs_ms_ks_lengths_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimK
&&
a_gs_ms_ks_strides_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimK
);
...
...
@@ -158,36 +175,72 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// lengths for K0, K1, ...
const
auto
kLengths
=
get_container_subset
(
a_ms_ks_lengths
,
kDimIds
);
if
constexpr
(
ASpec
==
TensorSpecialization
::
Packed
)
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
ASpec
==
TensorSpecialization
::
Packed
)
{
auto
M
=
container_reduce
(
mLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
K
=
container_reduce
(
kLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
const
auto
a_grid_desc_mraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
a_ms_ks_strides
[
Number
<
NumDimM
-
1
>
{}],
a_ms_ks_strides
[
Number
<
NumDimM
+
NumDimK
-
1
>
{}]));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
else
{
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const
auto
a_grid_desc_ms_ks
=
make_naive_tensor_descriptor
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const
auto
a_grid_desc_mraw_kraw
=
transform_tensor_descriptor
(
a_grid_desc_ms_ks
,
make_tuple
(
make_merge_transform
(
mLengths
),
make_merge_transform
(
kLengths
)),
make_tuple
(
mDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
}();
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
assert
(
K
%
K1
==
0
);
if
constexpr
(
AEnableLds
)
{
auto
M
=
container_reduce
(
mLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
K
=
container_reduce
(
kLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
const
auto
a_grid_desc_mraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
a_ms_ks_strides
[
Number
<
NumDimM
-
1
>
{}],
a_ms_ks_strides
[
Number
<
NumDimM
+
NumDimK
-
1
>
{}]));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
const
index_t
K0
=
K
/
K1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const
auto
a_grid_desc_ms_ks
=
make_naive_tensor_descriptor
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const
auto
a_grid_desc_mraw_kraw
=
transform_tensor_descriptor
(
a_grid_desc_ms_ks
,
make_tuple
(
make_merge_transform
(
mLengths
),
make_merge_transform
(
kLengths
)),
make_tuple
(
mDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
constexpr
auto
A_KRow
=
2
;
constexpr
auto
A_K0PerWmma
=
WmmaK
/
A_KRow
/
K1Number
;
const
auto
A_KWmma
=
K
/
WmmaK
;
const
auto
M0
=
M
/
MPerBlock
;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
A_KWmma
,
Number
<
A_K0PerWmma
>
{},
Number
<
A_KRow
>
{},
K1Number
)),
make_unmerge_transform
(
make_tuple
(
M0
*
MRepeat
,
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
3
,
4
,
6
>
{},
Sequence
<
1
,
2
,
5
>
{}));
}
}
// Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
static
auto
MakeBGridDescriptor
_N_K
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides_vec
)
static
auto
MakeBGridDescriptor
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides_vec
)
{
assert
(
b_gs_ns_ks_lengths_vec
.
size
()
==
NumDimG
+
NumDimN
+
NumDimK
&&
b_gs_ns_ks_strides_vec
.
size
()
==
NumDimG
+
NumDimN
+
NumDimK
);
...
...
@@ -214,30 +267,66 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// lengths for N0, N1, ...
const
auto
nLengths
=
get_container_subset
(
b_ns_ks_lengths
,
nDimIds
);
if
constexpr
(
BSpec
==
TensorSpecialization
::
Packed
)
const
auto
b_grid_desc_n_k
=
[
&
]()
{
if
constexpr
(
BSpec
==
TensorSpecialization
::
Packed
)
{
auto
N
=
container_reduce
(
nLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
K
=
container_reduce
(
kLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
const
auto
b_grid_desc_nraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
K
),
make_tuple
(
b_ns_ks_strides
[
Number
<
NumDimN
-
1
>
{}],
b_ns_ks_strides
[
Number
<
NumDimN
+
NumDimK
-
1
>
{}]));
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
else
{
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
const
auto
b_grid_desc_ns_ks
=
make_naive_tensor_descriptor
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
const
auto
b_grid_desc_nraw_kraw
=
transform_tensor_descriptor
(
b_grid_desc_ns_ks
,
make_tuple
(
make_merge_transform
(
nLengths
),
make_merge_transform
(
kLengths
)),
make_tuple
(
nDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
}();
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
assert
(
K
%
K1
==
0
);
if
constexpr
(
BEnableLds
)
{
auto
N
=
container_reduce
(
nLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
K
=
container_reduce
(
kLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
const
auto
b_grid_desc_nraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
K
),
make_tuple
(
b_ns_ks_strides
[
Number
<
NumDimN
-
1
>
{}],
b_ns_ks_strides
[
Number
<
NumDimN
+
NumDimK
-
1
>
{}]));
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
const
index_t
K0
=
K
/
K1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
const
auto
b_grid_desc_ns_ks
=
make_naive_tensor_descriptor
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
const
auto
b_grid_desc_nraw_kraw
=
transform_tensor_descriptor
(
b_grid_desc_ns_ks
,
make_tuple
(
make_merge_transform
(
nLengths
),
make_merge_transform
(
kLengths
)),
make_tuple
(
nDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
constexpr
auto
B_KRow
=
2
;
constexpr
auto
B_K0PerWmma
=
WmmaK
/
B_KRow
/
K1Number
;
const
auto
B_KWmma
=
K
/
WmmaK
;
const
auto
N0
=
N
/
NPerBlock
;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B_KWmma
,
Number
<
B_K0PerWmma
>
{},
Number
<
B_KRow
>
{},
K1Number
)),
make_unmerge_transform
(
make_tuple
(
N0
*
NRepeat
,
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
3
,
4
,
6
>
{},
Sequence
<
1
,
2
,
5
>
{}));
}
}
...
...
@@ -393,8 +482,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
}
// Gridwise descriptor, mapping to whole given provblem.
using
AGridDesc_M_K
=
decltype
(
MakeAGridDescriptor_M_K
({},
{}));
using
BGridDesc_N_K
=
decltype
(
MakeBGridDescriptor_N_K
({},
{}));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
({},
{}));
...
...
@@ -449,45 +536,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
EGridDesc_G_M_N
e_grid_desc_g_m_n_
;
};
// A desc for source in blockwise copy
template
<
typename
AGridDesc_M_K
>
__host__
__device__
static
constexpr
auto
MakeAGridDescriptor_K0_M_K1
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
)
{
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
AK0
=
K
/
K1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
K1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// B desc for source in blockwise copy
template
<
typename
BGridDesc_N_K
>
__host__
__device__
static
constexpr
auto
MakeBGridDescriptor_K0_N_K1
(
const
BGridDesc_N_K
&
b_grid_desc_n_k
)
{
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
BK0
=
K
/
K1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
using
AGridDesc_K0_M_K1
=
decltype
(
DeviceOp
::
MakeAGridDescriptor_K0_M_K1
(
AGridDesc_M_K
{}));
using
BGridDesc_K0_N_K1
=
decltype
(
DeviceOp
::
MakeBGridDescriptor_K0_N_K1
(
BGridDesc_N_K
{}));
using
AGridDesc
=
decltype
(
DeviceOp
::
MakeAGridDescriptor
({},
{}));
using
BGridDesc
=
decltype
(
DeviceOp
::
MakeBGridDescriptor
({},
{}));
// GridwiseOp
using
GridwiseOp
=
GridwiseGemmMultipleD_
k0mk1_k0nk1_mn_wmma_cshuffle
<
using
GridwiseOp
=
GridwiseGemmMultipleD_
Wmma
<
// DataType Family
ADataType
,
BDataType
,
...
...
@@ -496,8 +549,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
DsDataType
,
EDataType
,
// InMemory Data Descriptor
AGridDesc
_K0_M_K1
,
BGridDesc
_K0_N_K1
,
AGridDesc
,
BGridDesc
,
DsGridDesc_M_N
,
EGridDesc_M_N
,
// ElementwiseOp Family
...
...
@@ -508,9 +561,9 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// Tiling Family
MPerBlock
,
NPerBlock
,
K
0
PerBlock
,
MPerW
MMA
,
NPerW
MMA
,
KPerBlock
,
MPerW
mma
,
NPerW
mma
,
K1
,
MRepeat
,
NRepeat
,
...
...
@@ -523,6 +576,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
AEnableLds
,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
...
...
@@ -531,6 +585,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BEnableLds
,
BBlockLdsAddExtraN
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
...
...
@@ -564,16 +619,14 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
a_grid_desc_
m_k_
{},
b_grid_desc_
n_k_
{},
a_grid_desc_
{},
b_grid_desc_
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{},
ds_grid_desc_g_m_n_
{
DeviceOp
::
MakeDsGridDescriptor_G_M_N
(
ds_gs_ms_ns_lengths
,
ds_gs_ms_ns_strides
)},
e_grid_desc_g_m_n_
{
DeviceOp
::
MakeEGridDescriptor_G_M_N
(
e_gs_ms_ns_lengths
,
e_gs_ms_ns_strides
)},
a_grid_desc_k0_m_k1_
{},
b_grid_desc_k0_n_k1_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock
{},
e_grid_desc_mblock_mperblock_nblock_nperblock
{},
block_2_ctile_map_
{},
...
...
@@ -600,10 +653,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds_grid
[
i
]);
});
a_grid_desc_m_k_
=
DeviceOp
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
);
b_grid_desc_n_k_
=
DeviceOp
::
MakeBGridDescriptor_N_K
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
);
a_grid_desc_
=
DeviceOp
::
MakeAGridDescriptor
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
);
b_grid_desc_
=
DeviceOp
::
MakeBGridDescriptor
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
);
ds_grid_desc_m_n_
=
DeviceOp
::
MakeDsGridDescriptor_M_N
(
ds_gs_ms_ns_lengths
,
ds_gs_ms_ns_strides
);
...
...
@@ -611,9 +662,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
e_grid_desc_m_n_
=
DeviceOp
::
MakeEGridDescriptor_M_N
(
e_gs_ms_ns_lengths
,
e_gs_ms_ns_strides
);
a_grid_desc_k0_m_k1_
=
DeviceOp
::
MakeAGridDescriptor_K0_M_K1
(
a_grid_desc_m_k_
);
b_grid_desc_k0_n_k1_
=
DeviceOp
::
MakeBGridDescriptor_K0_N_K1
(
b_grid_desc_n_k_
);
block_2_ctile_map_
=
GridwiseOp
::
MakeDefaultBlock2CTileMap
(
e_grid_desc_m_n_
,
M01
,
N01
);
ds_grid_desc_mblock_mperblock_nblock_nperblock
=
...
...
@@ -644,16 +692,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
EDataType
*
p_e_grid_
;
// Tensor Descriptors
AGridDesc
_M_K
a_grid_desc_
m_k_
;
BGridDesc
_N_K
b_grid_desc_
n_k_
;
AGridDesc
a_grid_desc_
;
BGridDesc
b_grid_desc_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
DsGridDesc_G_M_N
ds_grid_desc_g_m_n_
;
EGridDesc_G_M_N
e_grid_desc_g_m_n_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
...
@@ -686,6 +731,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// Batch Offset
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
// for checking vector load/store
// index_t MRaw_;
// index_t NRaw_;
// index_t KRaw_;
};
// Invoker
...
...
@@ -700,8 +750,17 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
)
*
G
;
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
const
auto
K
=
[
&
]()
{
if
constexpr
(
AEnableLds
)
{
return
arg
.
a_grid_desc_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_
.
GetLength
(
I2
);
}
else
{
return
arg
.
a_grid_desc_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_
.
GetLength
(
I3
)
*
arg
.
a_grid_desc_
.
GetLength
(
I4
)
*
arg
.
a_grid_desc_
.
GetLength
(
I6
);
}
}();
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
...
...
@@ -712,8 +771,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
BDataType
,
typename
GridwiseOp
::
DsGridPointer
,
EDataType
,
DeviceOp
::
AGridDesc
_K0_M_K1
,
DeviceOp
::
BGridDesc
_K0_N_K1
,
DeviceOp
::
AGridDesc
,
DeviceOp
::
BGridDesc
,
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
AElementwiseOperation
,
...
...
@@ -733,8 +792,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
G
,
arg
.
a_grid_desc_
k0_m_k1_
,
arg
.
b_grid_desc_
k0_n_k1_
,
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
...
...
@@ -774,6 +833,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
printf
(
"DeviceOp: Arch check failure
\n
"
);
return
false
;
}
}
...
...
@@ -782,12 +842,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
return
false
;
}
if
(
!
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc_
k0_m_k1_
,
arg
.
b_grid_desc_
k0_n_k1_
,
if
(
!
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
))
{
printf
(
"GridwiseOp: Validity check failure
\n
"
);
return
false
;
}
...
...
@@ -800,16 +861,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
if
constexpr
(
ABlockTransferSrcVectorDim
==
1
)
{
if
(
!
(
arg
.
a_mz_stride_
==
1
&&
arg
.
a_grid_desc_
k0_m_k1_
.
GetLength
(
I1
)
%
ABlockTransferSrcScalarPerVector
==
0
))
arg
.
a_grid_desc_
.
GetLength
(
I1
)
%
ABlockTransferSrcScalarPerVector
==
0
))
{
printf
(
"DeviceOp: Vector Access A-m check failure
\n
"
);
return
false
;
}
}
else
{
if
(
!
(
arg
.
a_kz_stride_
==
1
&&
arg
.
a_grid_desc_
k0_m_k1_
.
GetLength
(
I2
)
%
ABlockTransferSrcScalarPerVector
==
0
))
arg
.
a_grid_desc_
.
GetLength
(
I2
)
%
ABlockTransferSrcScalarPerVector
==
0
))
{
printf
(
"DeviceOp: Vector Access A-k check failure
\n
"
);
return
false
;
}
}
...
...
@@ -818,16 +881,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
if
constexpr
(
BBlockTransferSrcVectorDim
==
1
)
{
if
(
!
(
arg
.
b_nz_stride_
==
1
&&
arg
.
b_grid_desc_
k0_n_k1_
.
GetLength
(
I1
)
%
BBlockTransferSrcScalarPerVector
==
0
))
arg
.
b_grid_desc_
.
GetLength
(
I1
)
%
BBlockTransferSrcScalarPerVector
==
0
))
{
printf
(
"DeviceOp: Vector Access B-n check failure
\n
"
);
return
false
;
}
}
else
{
if
(
!
(
arg
.
b_kz_stride_
==
1
&&
arg
.
b_grid_desc_
k0_n_k1_
.
GetLength
(
I2
)
%
BBlockTransferSrcScalarPerVector
==
0
))
arg
.
b_grid_desc_
.
GetLength
(
I2
)
%
BBlockTransferSrcScalarPerVector
==
0
))
{
printf
(
"DeviceOp: Vector Access B-k check failure
\n
"
);
return
false
;
}
}
...
...
@@ -841,6 +906,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
CDEShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
printf
(
"DeviceOp: Vector Access D-n check failure
\n
"
);
valid_d_access
=
false
;
}
});
...
...
@@ -857,6 +923,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
0
)
||
CDEShuffleBlockTransferScalarPerVector_NPerBlock
==
1
))
{
printf
(
"DeviceOp: Vector Access E-n check failure
\n
"
);
return
false
;
}
...
...
@@ -967,14 +1034,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K
0
PerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
K1
<<
", "
<<
MPerW
MMA
<<
", "
<<
NPerW
MMA
<<
", "
<<
MPerW
mma
<<
", "
<<
NPerW
mma
<<
", "
<<
MRepeat
<<
", "
<<
NRepeat
<<
">"
<<
" NumPrefetch: "
<<
" AEnableLds: "
<<
AEnableLds
<<
", "
<<
"BEnableLds: "
<<
BEnableLds
<<
", "
<<
"NumPrefetch: "
<<
NumPrefetch
<<
", "
<<
"LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
0 → 100644
View file @
9de63596
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
DeviceOp
,
typename
GridwiseOp
,
typename
ADataType
,
typename
B0DataType
,
typename
B1DataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_batched_gemm_softmax_gemm_wmma_cshuffle
(
const
ADataType
*
__restrict__
p_a_grid
,
const
B0DataType
*
__restrict__
p_b0_grid
,
const
B1DataType
*
__restrict__
p_b1_grid
,
CDataType
*
__restrict__
p_c_grid
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
O
,
index_t
G0
,
index_t
G1
,
float
alpha
,
bool
input_permute
,
bool
output_permute
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
// clang-format off
// ***************************************************
// Make Tensor Descriptors
constexpr
index_t
array_size
=
4
;
std
::
array
<
ck
::
index_t
,
array_size
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
array
<
ck
::
index_t
,
array_size
>
a_gs_ms_ks_strides
=
input_permute
?
std
::
array
<
ck
::
index_t
,
array_size
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// A layout [G0, M, G1, K]
:
std
::
array
<
ck
::
index_t
,
array_size
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1, M, K]
std
::
array
<
ck
::
index_t
,
array_size
>
b0_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
array
<
ck
::
index_t
,
array_size
>
b0_gs_ns_ks_strides
=
input_permute
?
std
::
array
<
ck
::
index_t
,
array_size
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// B0 layout [G0, N, G1, K]
:
std
::
array
<
ck
::
index_t
,
array_size
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1, N, K]
std
::
array
<
ck
::
index_t
,
array_size
>
b1_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
array
<
ck
::
index_t
,
array_size
>
b1_gs_os_ns_strides
=
input_permute
?
std
::
array
<
ck
::
index_t
,
array_size
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// B1 layout [G0, N, G1, O]
:
std
::
array
<
ck
::
index_t
,
array_size
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1, N, O]
std
::
array
<
ck
::
index_t
,
array_size
>
c_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
array
<
ck
::
index_t
,
array_size
>
c_gs_ms_os_strides
=
output_permute
?
std
::
array
<
ck
::
index_t
,
array_size
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// C layout [G0, M, G1, O]
:
std
::
array
<
ck
::
index_t
,
array_size
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1, M, O]
const
auto
a_element_op
=
AElementwiseOperation
{};
const
auto
b0_element_op
=
B0ElementwiseOperation
{};
const
auto
acc0_element_op
=
AccElementwiseOperation
{
alpha
};
const
auto
b1_element_op
=
B1ElementwiseOperation
{};
const
auto
c_element_op
=
CElementwiseOperation
{};
// fail to reuse DeviceOp::MakeArgument() because of the __device__ function required.
const
auto
a_grid_desc
=
DeviceOp
::
MakeAGridDescriptor
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
);
const
auto
b0_grid_desc
=
DeviceOp
::
MakeB0GridDescriptor
(
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_strides
);
const
auto
b1_grid_desc
=
DeviceOp
::
MakeB1GridDescriptor
(
b1_gs_os_ns_lengths
,
b1_gs_os_ns_strides
);
const
auto
c_grid_desc_m_n
=
DeviceOp
::
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseOp
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
const
auto
block_2_ctile_map
=
GridwiseOp
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
,
1
,
1
);
const
auto
a_grid_desc_g_m_k
=
DeviceOp
::
Transform
::
MakeAGridDescriptor_G_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
);
const
auto
b0_grid_desc_g_l_k
=
DeviceOp
::
Transform
::
MakeB0GridDescriptor_G_N_K
(
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_strides
);
const
auto
b1_grid_desc_g_n_l
=
DeviceOp
::
Transform
::
MakeB1GridDescriptor_G_N_K
(
b1_gs_os_ns_lengths
,
b1_gs_os_ns_strides
);
const
auto
c_grid_desc_g_m_n
=
DeviceOp
::
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
const
auto
compute_base_ptr_of_batch
=
typename
DeviceOp
::
ComputeBasePtrOfStridedBatch
{
a_grid_desc_g_m_k
,
b0_grid_desc_g_l_k
,
b1_grid_desc_g_n_l
,
c_grid_desc_g_m_n
};
index_t
batch_count
=
c_grid_desc_g_m_n
.
GetLength
(
Number
<
0
>
{});
const
auto
c0_matrix_mask
=
typename
DeviceOp
::
C0MatrixMask
{
b0_grid_desc_g_l_k
.
GetLength
(
Number
<
1
>
{})};
// clang-format on
__shared__
char
p_shared
[
GridwiseOp
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB0BasePtr
(
g_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetCBasePtr
(
g_idx
)));
GridwiseOp
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b0_grid
+
b0_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_shared
,
a_grid_desc
,
b0_grid_desc
,
b1_grid_desc
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
a_element_op
,
b0_element_op
,
acc0_element_op
,
b1_element_op
,
c_element_op
,
c0_matrix_mask
,
block_2_ctile_map
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b0_grid
;
ignore
=
p_b1_grid
;
ignore
=
p_c_grid
;
ignore
=
M
;
ignore
=
N
;
ignore
=
K
;
ignore
=
O
;
ignore
=
G0
;
ignore
=
G1
;
ignore
=
input_permute
;
ignore
=
output_permute
;
#endif // end of if (defined(__gfx11__))
}
// Self-Attention
template
<
typename
DeviceOp
,
typename
GridwiseOp
,
typename
QKVDataType
,
typename
ODataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_wmma_self_attention_forward
(
const
QKVDataType
*
__restrict__
p_qkv_grid
,
ODataType
*
__restrict__
p_out_grid
,
index_t
batch_size
,
index_t
sequence_length
,
index_t
head_count
,
index_t
head_size
,
float
alpha
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
// clang-format off
// ***************************************************
// Make Tensor Descriptors
// o Self-attention(packed QKV): [batchSize, sequenceLength, headCount, 3, headSize]
constexpr
index_t
array_size
=
4
;
std
::
array
<
ck
::
index_t
,
array_size
>
qk_gs_ms_ks_lengths
{
batch_size
,
head_count
,
sequence_length
,
head_size
};
std
::
array
<
ck
::
index_t
,
array_size
>
qk_gs_ms_ks_strides
{
sequence_length
*
head_count
*
3
*
head_size
,
3
*
head_size
,
head_count
*
3
*
head_size
,
1
};
std
::
array
<
ck
::
index_t
,
array_size
>
v_gs_os_ns_lengths
{
batch_size
,
head_count
,
head_size
,
sequence_length
};
std
::
array
<
ck
::
index_t
,
array_size
>
v_gs_os_ns_strides
{
sequence_length
*
head_count
*
3
*
head_size
,
3
*
head_size
,
1
,
head_count
*
3
*
head_size
};
std
::
array
<
ck
::
index_t
,
array_size
>
c_gs_ms_os_lengths
{
batch_size
,
head_count
,
sequence_length
,
head_size
};
std
::
array
<
ck
::
index_t
,
array_size
>
c_gs_ms_os_strides
{
sequence_length
*
head_count
*
head_size
,
head_size
,
head_count
*
head_size
,
1
};
const
auto
a_element_op
=
AElementwiseOperation
{};
const
auto
b0_element_op
=
B0ElementwiseOperation
{};
const
auto
acc0_element_op
=
AccElementwiseOperation
{
alpha
};
const
auto
b1_element_op
=
B1ElementwiseOperation
{};
const
auto
c_element_op
=
CElementwiseOperation
{};
const
auto
a_grid_desc
=
DeviceOp
::
MakeAGridDescriptor
(
qk_gs_ms_ks_lengths
,
qk_gs_ms_ks_strides
);
const
auto
b0_grid_desc
=
DeviceOp
::
MakeB0GridDescriptor
(
qk_gs_ms_ks_lengths
,
qk_gs_ms_ks_strides
);
const
auto
b1_grid_desc
=
DeviceOp
::
MakeB1GridDescriptor
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
const
auto
c_grid_desc_m_n
=
DeviceOp
::
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseOp
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
const
auto
block_2_ctile_map
=
GridwiseOp
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
,
1
,
1
);
const
auto
a_grid_desc_g_m_k
=
DeviceOp
::
Transform
::
MakeAGridDescriptor_G_M_K
(
qk_gs_ms_ks_lengths
,
qk_gs_ms_ks_strides
);
const
auto
b0_grid_desc_g_l_k
=
DeviceOp
::
Transform
::
MakeB0GridDescriptor_G_N_K
(
qk_gs_ms_ks_lengths
,
qk_gs_ms_ks_strides
);
const
auto
b1_grid_desc_g_n_l
=
DeviceOp
::
Transform
::
MakeB1GridDescriptor_G_N_K
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
const
auto
c_grid_desc_g_m_n
=
DeviceOp
::
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
const
auto
compute_base_ptr_of_batch
=
typename
DeviceOp
::
ComputeBasePtrOfStridedBatch
{
a_grid_desc_g_m_k
,
b0_grid_desc_g_l_k
,
b1_grid_desc_g_n_l
,
c_grid_desc_g_m_n
};
index_t
batch_count
=
c_grid_desc_g_m_n
.
GetLength
(
Number
<
0
>
{});
const
auto
c0_matrix_mask
=
typename
DeviceOp
::
C0MatrixMask
{
b0_grid_desc_g_l_k
.
GetLength
(
Number
<
1
>
{})};
// clang-format on
__shared__
char
p_shared
[
GridwiseOp
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB0BasePtr
(
g_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetCBasePtr
(
g_idx
)));
const
index_t
qkv_gap
=
__builtin_amdgcn_readfirstlane
(
head_size
);
#ifdef CK_SELF_ATTN_DEBUG
if
(
get_thread_global_1d_id
()
==
0
)
{
printf
(
"batch_size: %d
\n
"
,
batch_size
);
printf
(
"sequence_length: %d
\n
"
,
sequence_length
);
printf
(
"head_count: %d
\n
"
,
head_count
);
printf
(
"head_size: %d
\n
"
,
head_size
);
printf
(
"qkv_gap: %d
\n
"
,
qkv_gap
);
printf
(
"get_grid_size(): %d
\n
"
,
get_grid_size
());
printf
(
"batch_count: %d
\n
"
,
batch_count
);
printf
(
"blockid: %d
\n
"
,
get_block_1d_id
());
printf
(
"num_blocks_per_batch: %d
\n
"
,
num_blocks_per_batch
);
printf
(
"g_idx: %d
\n
"
,
g_idx
);
printf
(
"a_batch_offset: %ld
\n
"
,
a_batch_offset
);
printf
(
"b0_batch_offset: %ld
\n
"
,
b0_batch_offset
);
printf
(
"b1_batch_offset: %ld
\n
"
,
b1_batch_offset
);
}
#endif
GridwiseOp
::
template
Run
<
HasMainKBlockLoop
>(
p_qkv_grid
+
0
*
qkv_gap
+
a_batch_offset
,
p_qkv_grid
+
1
*
qkv_gap
+
b0_batch_offset
,
p_qkv_grid
+
2
*
qkv_gap
+
b1_batch_offset
,
p_out_grid
+
c_batch_offset
,
p_shared
,
a_grid_desc
,
b0_grid_desc
,
b1_grid_desc
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
a_element_op
,
b0_element_op
,
acc0_element_op
,
b1_element_op
,
c_element_op
,
c0_matrix_mask
,
block_2_ctile_map
);
#else
ignore
=
p_qkv_grid
;
ignore
=
p_out_grid
;
ignore
=
batch_size
;
ignore
=
sequence_length
;
ignore
=
head_count
;
ignore
=
head_size
;
ignore
=
alpha
;
#endif // end of if (defined(__gfx11__))
}
// Cross-Attention
// Self-Attention
template
<
typename
DeviceOp
,
typename
GridwiseOp
,
typename
QDataType
,
typename
KVDataType
,
typename
ODataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_wmma_cross_attention_forward
(
const
QDataType
*
__restrict__
p_q_grid
,
const
KVDataType
*
__restrict__
p_kv_grid
,
ODataType
*
__restrict__
p_out_grid
,
index_t
batch_size
,
index_t
q_sequence_length
,
index_t
kv_sequence_length
,
index_t
head_count
,
index_t
head_size
,
float
alpha
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
// clang-format off
// ***************************************************
// Make Tensor Descriptors
// o Self-attention(packed QKV): [batchSize, sequenceLength, headCount, 3, headSize]
constexpr
index_t
array_size
=
4
;
std
::
array
<
ck
::
index_t
,
array_size
>
q_gs_ms_ks_lengths
{
batch_size
,
head_count
,
q_sequence_length
,
head_size
};
std
::
array
<
ck
::
index_t
,
array_size
>
q_gs_ms_ks_strides
{
q_sequence_length
*
head_count
*
head_size
,
head_size
,
head_count
*
head_size
,
1
};
std
::
array
<
ck
::
index_t
,
array_size
>
k_gs_ms_ks_lengths
{
batch_size
,
head_count
,
kv_sequence_length
,
head_size
};
std
::
array
<
ck
::
index_t
,
array_size
>
k_gs_ms_ks_strides
{
kv_sequence_length
*
head_count
*
2
*
head_size
,
2
*
head_size
,
head_count
*
2
*
head_size
,
1
};
std
::
array
<
ck
::
index_t
,
array_size
>
v_gs_os_ns_lengths
{
batch_size
,
head_count
,
head_size
,
kv_sequence_length
};
std
::
array
<
ck
::
index_t
,
array_size
>
v_gs_os_ns_strides
{
kv_sequence_length
*
head_count
*
2
*
head_size
,
2
*
head_size
,
1
,
head_count
*
2
*
head_size
};
std
::
array
<
ck
::
index_t
,
array_size
>
c_gs_ms_os_lengths
{
batch_size
,
head_count
,
q_sequence_length
,
head_size
};
std
::
array
<
ck
::
index_t
,
array_size
>
c_gs_ms_os_strides
{
q_sequence_length
*
head_count
*
head_size
,
head_size
,
head_count
*
head_size
,
1
};
const
auto
a_element_op
=
AElementwiseOperation
{};
const
auto
b0_element_op
=
B0ElementwiseOperation
{};
const
auto
acc0_element_op
=
AccElementwiseOperation
{
alpha
};
const
auto
b1_element_op
=
B1ElementwiseOperation
{};
const
auto
c_element_op
=
CElementwiseOperation
{};
const
auto
a_grid_desc
=
DeviceOp
::
MakeAGridDescriptor
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
const
auto
b0_grid_desc
=
DeviceOp
::
MakeB0GridDescriptor
(
k_gs_ms_ks_lengths
,
k_gs_ms_ks_strides
);
const
auto
b1_grid_desc
=
DeviceOp
::
MakeB1GridDescriptor
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
const
auto
c_grid_desc_m_n
=
DeviceOp
::
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseOp
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
const
auto
block_2_ctile_map
=
GridwiseOp
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
,
1
,
1
);
const
auto
a_grid_desc_g_m_k
=
DeviceOp
::
Transform
::
MakeAGridDescriptor_G_M_K
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
const
auto
b0_grid_desc_g_l_k
=
DeviceOp
::
Transform
::
MakeB0GridDescriptor_G_N_K
(
k_gs_ms_ks_lengths
,
k_gs_ms_ks_strides
);
const
auto
b1_grid_desc_g_n_l
=
DeviceOp
::
Transform
::
MakeB1GridDescriptor_G_N_K
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
const
auto
c_grid_desc_g_m_n
=
DeviceOp
::
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
const
auto
compute_base_ptr_of_batch
=
typename
DeviceOp
::
ComputeBasePtrOfStridedBatch
{
a_grid_desc_g_m_k
,
b0_grid_desc_g_l_k
,
b1_grid_desc_g_n_l
,
c_grid_desc_g_m_n
};
index_t
batch_count
=
c_grid_desc_g_m_n
.
GetLength
(
Number
<
0
>
{});
const
auto
c0_matrix_mask
=
typename
DeviceOp
::
C0MatrixMask
{
b0_grid_desc_g_l_k
.
GetLength
(
Number
<
1
>
{})};
// clang-format on
__shared__
char
p_shared
[
GridwiseOp
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB0BasePtr
(
g_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetCBasePtr
(
g_idx
)));
const
index_t
kv_gap
=
__builtin_amdgcn_readfirstlane
(
head_size
);
#ifdef CK_SELF_ATTN_DEBUG
if
(
get_thread_global_1d_id
()
==
0
)
{
printf
(
"batch_size: %d
\n
"
,
batch_size
);
printf
(
"q_sequence_length: %d
\n
"
,
q_sequence_length
);
printf
(
"k_sequence_length: %d
\n
"
,
kv_sequence_length
);
printf
(
"head_count: %d
\n
"
,
head_count
);
printf
(
"head_size: %d
\n
"
,
head_size
);
printf
(
"kv_gap: %d
\n
"
,
kv_gap
);
printf
(
"get_grid_size(): %d
\n
"
,
get_grid_size
());
printf
(
"batch_count: %d
\n
"
,
batch_count
);
printf
(
"blockid: %d
\n
"
,
get_block_1d_id
());
printf
(
"num_blocks_per_batch: %d
\n
"
,
num_blocks_per_batch
);
printf
(
"g_idx: %d
\n
"
,
g_idx
);
printf
(
"a_batch_offset: %ld
\n
"
,
a_batch_offset
);
printf
(
"b0_batch_offset: %ld
\n
"
,
b0_batch_offset
);
printf
(
"b1_batch_offset: %ld
\n
"
,
b1_batch_offset
);
}
#endif
GridwiseOp
::
template
Run
<
HasMainKBlockLoop
>(
p_q_grid
+
a_batch_offset
,
p_kv_grid
+
0
*
kv_gap
+
b0_batch_offset
,
p_kv_grid
+
1
*
kv_gap
+
b1_batch_offset
,
p_out_grid
+
c_batch_offset
,
p_shared
,
a_grid_desc
,
b0_grid_desc
,
b1_grid_desc
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
a_element_op
,
b0_element_op
,
acc0_element_op
,
b1_element_op
,
c_element_op
,
c0_matrix_mask
,
block_2_ctile_map
);
#else
ignore
=
p_q_grid
;
ignore
=
p_kv_grid
;
ignore
=
p_out_grid
;
ignore
=
batch_size
;
ignore
=
q_sequence_length
;
ignore
=
kv_sequence_length
;
ignore
=
head_count
;
ignore
=
head_size
;
ignore
=
alpha
;
#endif // end of if (defined(__gfx11__))
}
// Computes C = A * B0 * B1
// MN = MK * KL * LN
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimL
,
index_t
NumDimK
,
index_t
NumDimN
,
typename
ADataType
,
typename
B0DataType
,
typename
B1DataType
,
typename
CDataType
,
typename
Acc0BiasDataType
,
typename
Acc0DataType
,
typename
Acc1BiasDataType
,
typename
Acc1DataType
,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
GemmSpecialization
GemmSpec
,
TensorSpecialization
ASpec
,
TensorSpecialization
B0Spec
,
TensorSpecialization
B1Spec
,
TensorSpecialization
CSpec
,
ck
::
index_t
NumPrefetch
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
LPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
AK1
,
ck
::
index_t
BK1
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
LTilePerBlock
,
ck
::
index_t
L1
,
ck
::
index_t
MPerWmma
,
ck
::
index_t
LPerWmma
,
ck
::
index_t
NPerWmma
,
ck
::
index_t
MRepeat
,
ck
::
index_t
LRepeat
,
ck
::
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsAddExtraM
,
typename
B0BlockTransferThreadClusterLengths_K0_L_K1
,
typename
B0BlockTransferThreadClusterArrangeOrder
,
typename
B0BlockTransferSrcAccessOrder
,
ck
::
index_t
B0BlockTransferSrcVectorDim
,
ck
::
index_t
B0BlockTransferSrcScalarPerVector
,
ck
::
index_t
B0BlockTransferDstScalarPerVector_K1
,
bool
B0BlockLdsAddExtraL
,
typename
B1BlockTransferThreadClusterLengths_L0_N_L1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
ck
::
index_t
B1BlockTransferSrcVectorDim
,
ck
::
index_t
B1BlockTransferSrcScalarPerVector
,
ck
::
index_t
B1BlockTransferDstScalarPerVector_L1
,
bool
B1BlockLdsAddExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpecialization
MaskingSpec
,
ck
::
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
v1
>
struct
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
:
public
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimL
,
NumDimK
,
NumDimN
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
MaskingSpec
>
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimL
>
0
&&
NumDimK
>
0
&&
NumDimN
>
0
,
"Number of dimension must be greater than 0"
);
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
();
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
();
// TODO ANT: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
static
constexpr
index_t
NumDimGemm0M
=
NumDimM
;
static
constexpr
index_t
NumDimGemm0N
=
NumDimL
;
static
constexpr
index_t
NumDimGemm0K
=
NumDimK
;
static
constexpr
index_t
NumDimGemm1M
=
NumDimM
;
static
constexpr
index_t
NumDimGemm1N
=
NumDimN
;
static
constexpr
index_t
NumDimGemm1K
=
NumDimL
;
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
WmmaK
=
16
;
static
constexpr
auto
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
static
constexpr
auto
LWaves
=
LPerBlock
/
(
LRepeat
*
LPerWmma
);
static
constexpr
auto
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
static
constexpr
auto
AEnableLds_auto
=
LWaves
==
1
?
false
:
true
;
static
constexpr
auto
B0EnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
B1EnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
AEnableLds_manu
=
false
;
static
constexpr
auto
B0EnableLds_manu
=
true
;
static
constexpr
auto
B1EnableLds_manu
=
true
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
B0EnableLds
=
B0EnableLds_auto
||
B0EnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
B1EnableLds
=
B1EnableLds_auto
||
B1EnableLds_manu
||
(
NumPrefetch
>
1
);
using
Transform
=
TransformBatchedContractionContractionToBatchedGemmGemm_Wmma
<
Sequence
<
NumDimG
,
NumDimM
,
NumDimL
,
NumDimK
,
NumDimN
>
,
Sequence
<
MPerBlock
,
LPerBlock
,
KPerBlock
,
NPerBlock
>
,
GemmSpec
,
ASpec
,
B0Spec
,
B1Spec
,
CSpec
>
;
__host__
__device__
static
auto
MakeAGridDescriptor
(
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_strides_vec
)
{
if
constexpr
(
AEnableLds
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
),
Number
<
AK1
>
{});
}
else
{
return
Transform
::
MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
),
Number
<
WmmaK
>
{},
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{},
Number
<
AK1
>
{});
}
}
__host__
__device__
static
auto
MakeB0GridDescriptor
(
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b0_gs_ls_ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b0_gs_ls_ks_strides_vec
)
{
if
constexpr
(
B0EnableLds
)
{
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b0_gs_ls_ks_lengths_vec
,
b0_gs_ls_ks_strides_vec
),
Number
<
BK1
>
{});
}
else
{
return
Transform
::
MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b0_gs_ls_ks_lengths_vec
,
b0_gs_ls_ks_strides_vec
),
Number
<
WmmaK
>
{},
Number
<
LRepeat
>
{},
Number
<
LWaves
>
{},
Number
<
LPerWmma
>
{},
Number
<
BK1
>
{});
}
}
__host__
__device__
static
auto
MakeB1GridDescriptor
(
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_ns_ls_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_ns_ls_strides_vec
)
{
if
constexpr
(
B1EnableLds
)
{
return
Transform
::
MakeB1GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_ns_ls_lengths_vec
,
b1_gs_ns_ls_strides_vec
),
Number
<
L1
>
{});
}
else
{
return
Transform
::
MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1
(
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_ns_ls_lengths_vec
,
b1_gs_ns_ls_strides_vec
),
Number
<
WmmaK
>
{},
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{},
Number
<
L1
>
{});
}
}
using
AGridDesc
=
decltype
(
MakeAGridDescriptor
({},
{}));
using
B0GridDesc
=
decltype
(
MakeB0GridDescriptor
({},
{}));
using
B1GridDesc
=
decltype
(
MakeB1GridDescriptor
({},
{}));
using
CGridDesc_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
AGridDesc_G_M_K
=
decltype
(
Transform
::
MakeAGridDescriptor_G_M_K
({},
{}));
using
B0GridDesc_G_L_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_L
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
__host__
__device__
constexpr
static
auto
make_MaskOutPredicate
()
{
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskDisabled
)
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskOutUpperTriangle
)
{
return
MaskOutUpperTrianglePredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
struct
ComputeBasePtrOfStridedBatch
{
__host__
__device__
ComputeBasePtrOfStridedBatch
(
const
AGridDesc_G_M_K
&
a_grid_desc_g_m_k
,
const
B0GridDesc_G_L_K
&
b0_grid_desc_g_l_k
,
const
B1GridDesc_G_N_L
&
b1_grid_desc_g_n_l
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b0_grid_desc_g_l_k_
(
b0_grid_desc_g_l_k
),
b1_grid_desc_g_n_l_
(
b1_grid_desc_g_n_l
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
)
{
}
__host__
__device__
constexpr
long_index_t
GetABasePtr
(
index_t
g_idx
)
const
{
return
a_grid_desc_g_m_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetB0BasePtr
(
index_t
g_idx
)
const
{
return
b0_grid_desc_g_l_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetB1BasePtr
(
index_t
g_idx
)
const
{
return
b1_grid_desc_g_n_l_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetCBasePtr
(
index_t
g_idx
)
const
{
return
c_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
B0GridDesc_G_L_K
b0_grid_desc_g_l_k_
;
B1GridDesc_G_N_L
b1_grid_desc_g_n_l_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
};
// GridwiseOp
using
GridwiseOp
=
GridwiseBatchedGemmSoftmaxGemm_Wmma
<
// DataType Family
ADataType
,
B0DataType
,
Acc0DataType
,
B1DataType
,
Acc1DataType
,
CShuffleDataType
,
CDataType
,
// ElementwiseOp Family
AElementwiseOperation
,
B0ElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
// InMemory Data Descriptor
AGridDesc
,
B0GridDesc
,
B1GridDesc
,
CGridDesc_M_N
,
// Tiling Family
MPerBlock
,
LPerBlock
,
KPerBlock
,
AK1
,
BK1
,
NPerBlock
,
LTilePerBlock
,
L1
,
MPerWmma
,
LPerWmma
,
NPerWmma
,
MRepeat
,
LRepeat
,
NRepeat
,
// ThreadCluster Family
BlockSize
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
true
,
AEnableLds
,
ABlockLdsAddExtraM
,
B0BlockTransferThreadClusterLengths_K0_L_K1
,
B0BlockTransferThreadClusterArrangeOrder
,
B0BlockTransferSrcAccessOrder
,
B0BlockTransferSrcVectorDim
,
B0BlockTransferSrcScalarPerVector
,
B0BlockTransferDstScalarPerVector_K1
,
true
,
B0EnableLds
,
B0BlockLdsAddExtraL
,
B1BlockTransferThreadClusterLengths_L0_N_L1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferSrcAccessOrder
,
B1BlockTransferSrcVectorDim
,
B1BlockTransferSrcScalarPerVector
,
B1BlockTransferDstScalarPerVector_L1
,
false
,
B1EnableLds
,
B1BlockLdsAddExtraN
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
==
MaskingSpecialization
::
MaskOutUpperTriangle
,
NumPrefetch
,
LoopSched
,
PipelineVer
>
;
struct
RawArg
:
public
BaseArgument
{
RawArg
(
const
ADataType
*
p_a_grid
,
const
B0DataType
*
p_b0_grid
,
const
B1DataType
*
p_b1_grid
,
CDataType
*
p_c_grid
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
O
,
index_t
G0
,
index_t
G1
,
float
alpha
,
bool
input_permute
,
bool
output_permute
)
:
p_a_grid_
{
p_a_grid
},
p_b0_grid_
{
p_b0_grid
},
p_b1_grid_
{
p_b1_grid
},
p_c_grid_
{
p_c_grid
},
M_
{
M
},
N_
{
N
},
K_
{
K
},
O_
{
O
},
G0_
{
G0
},
G1_
{
G1
},
alpha_
{
alpha
},
input_permute_
{
input_permute
},
output_permute_
{
output_permute
}
{
}
// Pointers
const
ADataType
*
p_a_grid_
;
const
B0DataType
*
p_b0_grid_
;
const
B1DataType
*
p_b1_grid_
;
CDataType
*
p_c_grid_
;
// Raw Problem Size
index_t
M_
;
index_t
N_
;
index_t
K_
;
index_t
O_
;
index_t
G0_
;
index_t
G1_
;
float
alpha_
;
bool
input_permute_
;
bool
output_permute_
;
};
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
B0DataType
*
p_b0
,
const
B1DataType
*
p_b1
,
CDataType
*
p_c
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
O
,
index_t
G0
,
index_t
G1
,
float
alpha
,
bool
input_permute
,
bool
output_permute
)
{
return
RawArg
{
p_a
,
p_b0
,
p_b1
,
p_c
,
M
,
N
,
K
,
O
,
G0
,
G1
,
alpha
,
input_permute
,
output_permute
};
}
static
bool
IsSupportedArgument
(
const
RawArg
&
arg
)
{
if
(
ck
::
is_navi3_supported
())
{
if
constexpr
(
!
(
is_same_v
<
Acc0DataType
,
float
>
||
is_same_v
<
Acc0DataType
,
int32_t
>
))
{
printf
(
"DeviceOp: Acc0 Type err"
);
return
false
;
}
if
constexpr
(
!
(
is_same_v
<
Acc1DataType
,
float
>
||
is_same_v
<
Acc1DataType
,
int32_t
>
))
{
printf
(
"DeviceOp: Acc1 Type err"
);
return
false
;
}
}
else
{
printf
(
"DeviceOp: Arch err"
);
return
false
;
}
constexpr
index_t
array_size
=
4
;
ck
::
index_t
G0
=
arg
.
G0_
;
ck
::
index_t
G1
=
arg
.
G1_
;
ck
::
index_t
M
=
arg
.
M_
;
ck
::
index_t
N
=
arg
.
N_
;
ck
::
index_t
K
=
arg
.
K_
;
ck
::
index_t
O
=
arg
.
O_
;
bool
input_permute
=
arg
.
input_permute_
;
bool
output_permute
=
arg
.
output_permute_
;
std
::
array
<
ck
::
index_t
,
array_size
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
array
<
ck
::
index_t
,
array_size
>
a_gs_ms_ks_strides
=
input_permute
?
std
::
array
<
ck
::
index_t
,
array_size
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// A layout [G0, M, G1, K]
:
std
::
array
<
ck
::
index_t
,
array_size
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1, M, K]
std
::
array
<
ck
::
index_t
,
array_size
>
b0_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
array
<
ck
::
index_t
,
array_size
>
b0_gs_ns_ks_strides
=
input_permute
?
std
::
array
<
ck
::
index_t
,
array_size
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// B0 layout [G0, N, G1, K]
:
std
::
array
<
ck
::
index_t
,
array_size
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1, N, K]
std
::
array
<
ck
::
index_t
,
array_size
>
b1_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
array
<
ck
::
index_t
,
array_size
>
b1_gs_os_ns_strides
=
input_permute
?
std
::
array
<
ck
::
index_t
,
array_size
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// B1 layout [G0, N, G1, O]
:
std
::
array
<
ck
::
index_t
,
array_size
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1, N, O]
std
::
array
<
ck
::
index_t
,
array_size
>
c_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
array
<
ck
::
index_t
,
array_size
>
c_gs_ms_os_strides
=
output_permute
?
std
::
array
<
ck
::
index_t
,
array_size
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// C layout [G0, M, G1, O]
:
std
::
array
<
ck
::
index_t
,
array_size
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1, M, O]
const
auto
a_grid_desc
=
DeviceOp
::
MakeAGridDescriptor
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
);
const
auto
b0_grid_desc
=
DeviceOp
::
MakeB0GridDescriptor
(
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_strides
);
const
auto
b1_grid_desc
=
DeviceOp
::
MakeB1GridDescriptor
(
b1_gs_os_ns_lengths
,
b1_gs_os_ns_strides
);
const
auto
c_grid_desc_m_n
=
DeviceOp
::
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
const
auto
block_2_ctile_map
=
GridwiseOp
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
,
1
,
1
);
const
auto
c_grid_desc_g_m_n
=
DeviceOp
::
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
index_t
batch_count
=
c_grid_desc_g_m_n
.
GetLength
(
Number
<
0
>
{});
if
(
!
GridwiseOp
::
CheckValidity
(
a_grid_desc
,
b0_grid_desc
,
b1_grid_desc
,
c_grid_desc_m_n
,
block_2_ctile_map
))
{
return
false
;
}
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
// unpadded
if
(
!
(
c_g
==
batch_count
))
{
printf
(
"DeviceOp: BatchCount err"
);
return
false
;
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const
auto
MzRaw
=
M
;
const
auto
LzRaw
=
N
;
const
auto
KzRaw
=
K
;
const
auto
NzRaw
=
O
;
// Check scalar per vector requirement
const
auto
a_extent_lowest
=
ABlockTransferSrcVectorDim
==
2
?
KzRaw
:
MzRaw
;
const
auto
b0_extent_lowest
=
B0BlockTransferSrcVectorDim
==
2
?
KzRaw
:
LzRaw
;
const
auto
b1_extent_lowest
=
B1BlockTransferSrcVectorDim
==
2
?
LzRaw
:
NzRaw
;
const
auto
c_extent_lowest
=
NzRaw
;
if
(
!
(
a_extent_lowest
%
ABlockTransferSrcScalarPerVector
==
0
&&
b0_extent_lowest
%
B0BlockTransferSrcScalarPerVector
==
0
&&
b1_extent_lowest
%
B1BlockTransferSrcScalarPerVector
==
0
&&
c_extent_lowest
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
printf
(
"DeviceOp: Data Transfer Vector scalar err"
);
return
false
;
}
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
a_mz_kz_strides_
{
a_gs_ms_ks_strides
[
NumDimG
+
NumDimM
-
1
],
a_gs_ms_ks_strides
[
NumDimG
+
NumDimM
+
NumDimK
-
1
]};
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
b0_lz_kz_strides_
{
b0_gs_ns_ks_strides
[
NumDimG
+
NumDimL
-
1
],
b0_gs_ns_ks_strides
[
NumDimG
+
NumDimL
+
NumDimK
-
1
]};
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
b1_nz_lz_strides_
{
b1_gs_os_ns_strides
[
NumDimG
+
NumDimN
-
1
],
b1_gs_os_ns_strides
[
NumDimG
+
NumDimN
+
NumDimL
-
1
]};
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
c_mz_nz_strides_
{
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
-
1
],
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
+
NumDimN
-
1
]};
// Check vector load/store requirement
const
auto
a_stride_lowest
=
ABlockTransferSrcVectorDim
==
2
?
a_mz_kz_strides_
[
1
]
:
a_mz_kz_strides_
[
0
];
const
auto
b0_stride_lowest
=
B0BlockTransferSrcVectorDim
==
2
?
b0_lz_kz_strides_
[
1
]
:
b0_lz_kz_strides_
[
0
];
const
auto
b1_stride_lowest
=
B1BlockTransferSrcVectorDim
==
2
?
b1_nz_lz_strides_
[
1
]
:
b1_nz_lz_strides_
[
0
];
const
auto
c_stride_lowest
=
c_mz_nz_strides_
[
1
];
if
(
!
(
a_stride_lowest
==
1
||
b0_stride_lowest
==
1
||
b1_stride_lowest
==
1
||
c_stride_lowest
==
1
))
{
printf
(
"DeviceOp: Data Vectorize transfer err"
);
return
false
;
}
return
true
;
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
RawArg
*>
(
p_arg
));
}
struct
SelfAttnArg
:
public
BaseArgument
{
SelfAttnArg
(
const
ADataType
*
p_qkv_grid
,
CDataType
*
p_out_grid
,
index_t
batch_size
,
index_t
sequence_length
,
index_t
head_count
,
index_t
head_size
,
float
alpha
)
:
p_qkv_grid_
{
p_qkv_grid
},
p_out_grid_
{
p_out_grid
},
batch_size_
{
batch_size
},
sequence_length_
{
sequence_length
},
head_count_
{
head_count
},
head_size_
{
head_size
},
alpha_
{
alpha
}
{
}
// Pointers
const
ADataType
*
p_qkv_grid_
;
CDataType
*
p_out_grid_
;
// Raw Problem Size
index_t
batch_size_
;
index_t
sequence_length_
;
index_t
head_count_
;
index_t
head_size_
;
float
alpha_
;
};
static
auto
MakeSelfAttnArgument
(
const
ADataType
*
p_qkv_grid
,
CDataType
*
p_out_grid
,
index_t
batch_size
,
index_t
sequence_length
,
index_t
head_count
,
index_t
head_size
,
float
alpha
)
{
return
SelfAttnArg
{
p_qkv_grid
,
p_out_grid
,
batch_size
,
sequence_length
,
head_count
,
head_size
,
alpha
};
}
struct
CrossAttnArg
:
public
BaseArgument
{
CrossAttnArg
(
const
ADataType
*
p_q_grid
,
const
B0DataType
*
p_kv_grid
,
CDataType
*
p_out_grid
,
index_t
batch_size
,
index_t
q_sequence_length
,
index_t
kv_sequence_length
,
index_t
head_count
,
index_t
head_size
,
float
alpha
)
:
p_q_grid_
{
p_q_grid
},
p_kv_grid_
{
p_kv_grid
},
p_out_grid_
{
p_out_grid
},
batch_size_
{
batch_size
},
q_sequence_length_
{
q_sequence_length
},
kv_sequence_length_
{
kv_sequence_length
},
head_count_
{
head_count
},
head_size_
{
head_size
},
alpha_
{
alpha
}
{
}
// Pointers
const
ADataType
*
p_q_grid_
;
const
B0DataType
*
p_kv_grid_
;
CDataType
*
p_out_grid_
;
// Raw Problem Size
index_t
batch_size_
;
index_t
q_sequence_length_
;
index_t
kv_sequence_length_
;
index_t
head_count_
;
index_t
head_size_
;
float
alpha_
;
};
static
auto
MakeCrossAttnArgument
(
const
ADataType
*
p_q_grid
,
const
B0DataType
*
p_kv_grid
,
CDataType
*
p_out_grid
,
index_t
batch_size
,
index_t
q_sequence_length
,
index_t
kv_sequence_length
,
index_t
head_count
,
index_t
head_size
,
float
alpha
)
{
return
CrossAttnArg
{
p_q_grid
,
p_kv_grid
,
p_out_grid
,
batch_size
,
q_sequence_length
,
kv_sequence_length
,
head_count
,
head_size
,
alpha
};
}
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
ADataType
*
p_a_grid
,
const
B0DataType
*
p_b0_grid
,
const
B1DataType
*
p_b1_grid
,
CDataType
*
p_c_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_lengths
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_strides
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b0_gs_ls_ks_lengths
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b0_gs_ls_ks_strides
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_ns_ls_lengths
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_ns_ls_strides
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
c_gs_ms_ns_lengths
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
c_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ls_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ls_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_ns_strides
,
const
index_t
M01
,
const
index_t
N01
,
AElementwiseOperation
a_element_op
,
B0ElementwiseOperation
b0_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
:
p_a_grid_
{
p_a_grid
},
p_b0_grid_
{
p_b0_grid
},
p_b1_grid_
{
p_b1_grid
},
p_c_grid_
{
p_c_grid
},
a_grid_desc
{
DeviceOp
::
MakeAGridDescriptor
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
b0_grid_desc
{
DeviceOp
::
MakeB0GridDescriptor
(
b0_gs_ls_ks_lengths
,
b0_gs_ls_ks_strides
)},
b1_grid_desc
{
DeviceOp
::
MakeB1GridDescriptor
(
b1_gs_ns_ls_lengths
,
b1_gs_ns_ls_strides
)},
c_grid_desc_m_n_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_ns_lengths
,
c_gs_ms_ns_strides
)},
a_grid_desc_g_m_k_
{
Transform
::
MakeAGridDescriptor_G_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
b0_grid_desc_g_l_k_
{
Transform
::
MakeB0GridDescriptor_G_N_K
(
b0_gs_ls_ks_lengths
,
b0_gs_ls_ks_strides
)},
b1_grid_desc_g_n_l_
{
Transform
::
MakeB1GridDescriptor_G_N_K
(
b1_gs_ns_ls_lengths
,
b1_gs_ns_ls_strides
)},
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_ns_lengths
,
c_gs_ms_ns_strides
)},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_ctile_map_
{
GridwiseOp
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
)},
a_element_op_
{
a_element_op
},
b0_element_op_
{
b0_element_op
},
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c0_matrix_mask_
{
b0_grid_desc_g_l_k_
.
GetLength
(
I1
)},
raw_lengths_mz_lz_kz_nz_
{
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
b0_gs_ls_ks_lengths
[
NumDimG
+
NumDimL
-
1
],
b0_gs_ls_ks_lengths
[
NumDimG
+
NumDimL
+
NumDimK
-
1
],
b1_gs_ns_ls_lengths
[
NumDimG
+
NumDimN
-
1
]},
a_mz_kz_strides_
{
a_gs_ms_ks_strides
[
NumDimG
+
NumDimM
-
1
],
a_gs_ms_ks_strides
[
NumDimG
+
NumDimM
+
NumDimK
-
1
]},
b0_lz_kz_strides_
{
b0_gs_ls_ks_strides
[
NumDimG
+
NumDimL
-
1
],
b0_gs_ls_ks_strides
[
NumDimG
+
NumDimL
+
NumDimK
-
1
]},
b1_nz_lz_strides_
{
b1_gs_ns_ls_strides
[
NumDimG
+
NumDimN
-
1
],
b1_gs_ns_ls_strides
[
NumDimG
+
NumDimN
+
NumDimL
-
1
]},
c_mz_nz_strides_
{
c_gs_ms_ns_strides
[
NumDimG
+
NumDimM
-
1
],
c_gs_ms_ns_strides
[
NumDimG
+
NumDimM
+
NumDimN
-
1
]},
batch_count_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)},
compute_ptr_offset_of_batch_
{
a_grid_desc_g_m_k_
,
b0_grid_desc_g_l_k_
,
b1_grid_desc_g_n_l_
,
c_grid_desc_g_m_n_
}
{
// TODO ANT: implement bias addition
ignore
=
p_acc0_biases
;
ignore
=
p_acc1_biases
;
ignore
=
acc0_biases_gs_ms_ls_lengths
;
ignore
=
acc0_biases_gs_ms_ls_strides
;
ignore
=
acc1_biases_gs_ms_ns_lengths
;
ignore
=
acc1_biases_gs_ms_ns_strides
;
if
(
GridwiseOp
::
CheckValidity
(
a_grid_desc
,
b0_grid_desc
,
b1_grid_desc
,
c_grid_desc_m_n_
,
block_2_ctile_map_
))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseOp
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
}
}
// Pointers
const
ADataType
*
p_a_grid_
;
const
B0DataType
*
p_b0_grid_
;
const
B1DataType
*
p_b1_grid_
;
CDataType
*
p_c_grid_
;
// Tensor Descriptors
AGridDesc
a_grid_desc
;
B0GridDesc
b0_grid_desc
;
B1GridDesc
b1_grid_desc
;
CGridDesc_M_N
c_grid_desc_m_n_
;
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
B0GridDesc_G_L_K
b0_grid_desc_g_l_k_
;
B1GridDesc_G_N_L
b1_grid_desc_g_n_l_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
typename
GridwiseOp
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
// Block to Tile mapping
typename
GridwiseOp
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
// ElementwiseOp
AElementwiseOperation
a_element_op_
;
B0ElementwiseOperation
b0_element_op_
;
AccElementwiseOperation
acc_element_op_
;
B1ElementwiseOperation
b1_element_op_
;
CElementwiseOperation
c_element_op_
;
// check C0 masking and padding
C0MatrixMask
c0_matrix_mask_
;
// Strides for the last M/N/K dimensions of A/B0/B1/C
// for sanity check of vector load/store
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
raw_lengths_mz_lz_kz_nz_
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
a_mz_kz_strides_
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
b0_lz_kz_strides_
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
b1_nz_lz_strides_
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
c_mz_nz_strides_
;
index_t
batch_count_
;
// Batch Offset
ComputeBasePtrOfStridedBatch
compute_ptr_offset_of_batch_
;
};
// Invoker
struct
SelfAttnInvoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
SelfAttnArg
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
M0
=
math
::
integer_divide_ceil
(
arg
.
sequence_length_
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
arg
.
head_size_
,
NPerBlock
);
const
index_t
grid_size
=
arg
.
batch_size_
*
arg
.
head_count_
*
M0
*
N0
;
const
auto
K
=
arg
.
head_size_
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
const
auto
kernel
=
kernel_wmma_self_attention_forward
<
DeviceOp
,
GridwiseOp
,
ADataType
,
CDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
has_main_k_block_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_qkv_grid_
,
arg
.
p_out_grid_
,
arg
.
batch_size_
,
arg
.
sequence_length_
,
arg
.
head_count_
,
arg
.
head_size_
,
arg
.
alpha_
);
};
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
auto
MakeSelfAttnInvoker
()
{
return
SelfAttnInvoker
{};
}
// Invoker
struct
CrossAttnInvoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
CrossAttnArg
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
M0
=
math
::
integer_divide_ceil
(
arg
.
q_sequence_length_
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
arg
.
head_size_
,
NPerBlock
);
const
index_t
grid_size
=
arg
.
batch_size_
*
arg
.
head_count_
*
M0
*
N0
;
const
auto
K
=
arg
.
head_size_
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
const
auto
kernel
=
kernel_wmma_cross_attention_forward
<
DeviceOp
,
GridwiseOp
,
ADataType
,
B0DataType
,
CDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
has_main_k_block_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_q_grid_
,
arg
.
p_kv_grid_
,
arg
.
p_out_grid_
,
arg
.
batch_size_
,
arg
.
q_sequence_length_
,
arg
.
kv_sequence_length_
,
arg
.
head_count_
,
arg
.
head_size_
,
arg
.
alpha_
);
};
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
auto
MakeCrossAttnInvoker
()
{
return
CrossAttnInvoker
{};
}
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
RawArg
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
M0
=
math
::
integer_divide_ceil
(
arg
.
M_
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
arg
.
O_
,
NPerBlock
);
const
index_t
grid_size
=
arg
.
G0_
*
arg
.
G1_
*
M0
*
N0
;
const
auto
K
=
arg
.
K_
;
// printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K));
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
const
auto
kernel
=
kernel_batched_gemm_softmax_gemm_wmma_cshuffle
<
DeviceOp
,
GridwiseOp
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
has_main_k_block_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b0_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_c_grid_
,
arg
.
M_
,
arg
.
N_
,
arg
.
K_
,
arg
.
O_
,
arg
.
G0_
,
arg
.
G1_
,
arg
.
alpha_
,
arg
.
input_permute_
,
arg
.
output_permute_
);
};
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
#if 0
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
printf("DeviceOp: Acc0 Type err");
return false;
}
if constexpr(!(is_same_v<Acc1DataType, float> || is_same_v<Acc1DataType, int32_t>))
{
printf("DeviceOp: Acc1 Type err");
return false;
}
}
else
{
printf("DeviceOp: Arch err");
return false;
}
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b0_grid_desc,
arg.b1_grid_desc,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
return false;
}
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
if(!(c_g == arg.batch_count_))
{
printf("DeviceOp: BatchCount err");
return false;
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0];
const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1];
const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2];
const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3];
// Check scalar per vector requirement
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
const auto c_extent_lowest = NzRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
printf("DeviceOp: Data Transfer Vector scalar err");
return false;
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
const auto b0_stride_lowest =
B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0];
const auto b1_stride_lowest =
B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0];
const auto c_stride_lowest = arg.c_mz_nz_strides_[1];
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{
printf("DeviceOp: Data Vectorize transfer err");
return false;
}
return true;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(
const ADataType* p_a,
const B0DataType* p_b0,
const B1DataType* p_b1,
CDataType* p_c,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b0,
p_b1,
p_c,
p_acc0_biases,
p_acc1_biases,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ls_ks_lengths,
b0_gs_ls_ks_strides,
b1_gs_ns_ls_lengths,
b1_gs_ns_ls_strides,
c_gs_ms_ns_lengths,
c_gs_ms_ns_strides,
acc0_biases_gs_ms_ls_lengths,
acc0_biases_gs_ms_ls_strides,
acc1_biases_gs_ms_ns_lengths,
acc1_biases_gs_ms_ns_strides,
1,
1,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op};
}
#endif
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b0
,
const
void
*
p_b1
,
void
*
p_c
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b0_gs_ls_ks_lengths
,
const
std
::
vector
<
index_t
>&
b0_gs_ls_ks_strides
,
const
std
::
vector
<
index_t
>&
b1_gs_ns_ls_lengths
,
const
std
::
vector
<
index_t
>&
b1_gs_ns_ls_strides
,
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ls_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ls_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_ns_strides
,
AElementwiseOperation
a_element_op
,
B0ElementwiseOperation
b0_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
override
{
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
a_lengths
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
a_strides
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
b0_lengths
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
b0_strides
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
b1_lengths
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
b1_strides
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
c_lengths
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
c_strides
;
std
::
transform
(
a_gs_ms_ks_lengths
.
begin
(),
a_gs_ms_ks_lengths
.
end
(),
a_lengths
.
begin
(),
[](
index_t
i
)
{
return
i
;
});
std
::
transform
(
a_gs_ms_ks_strides
.
begin
(),
a_gs_ms_ks_strides
.
end
(),
a_strides
.
begin
(),
[](
index_t
i
)
{
return
i
;
});
std
::
transform
(
b0_gs_ls_ks_lengths
.
begin
(),
b0_gs_ls_ks_lengths
.
end
(),
b0_lengths
.
begin
(),
[](
index_t
i
)
{
return
i
;
});
std
::
transform
(
b0_gs_ls_ks_strides
.
begin
(),
b0_gs_ls_ks_strides
.
end
(),
b0_strides
.
begin
(),
[](
index_t
i
)
{
return
i
;
});
std
::
transform
(
b1_gs_ns_ls_lengths
.
begin
(),
b1_gs_ns_ls_lengths
.
end
(),
b1_lengths
.
begin
(),
[](
index_t
i
)
{
return
i
;
});
std
::
transform
(
b1_gs_ns_ls_strides
.
begin
(),
b1_gs_ns_ls_strides
.
end
(),
b1_strides
.
begin
(),
[](
index_t
i
)
{
return
i
;
});
std
::
transform
(
c_gs_ms_ns_lengths
.
begin
(),
c_gs_ms_ns_lengths
.
end
(),
c_lengths
.
begin
(),
[](
index_t
i
)
{
return
i
;
});
std
::
transform
(
c_gs_ms_ns_strides
.
begin
(),
c_gs_ms_ns_strides
.
end
(),
c_strides
.
begin
(),
[](
index_t
i
)
{
return
i
;
});
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
B0DataType
*>
(
p_b0
),
static_cast
<
const
B1DataType
*>
(
p_b1
),
static_cast
<
CDataType
*>
(
p_c
),
p_acc0_biases
,
p_acc1_biases
,
a_lengths
,
a_strides
,
b0_lengths
,
b0_strides
,
b1_lengths
,
b1_strides
,
c_lengths
,
c_strides
,
acc0_biases_gs_ms_ls_lengths
,
acc0_biases_gs_ms_ls_strides
,
acc1_biases_gs_ms_ns_lengths
,
acc1_biases_gs_ms_ns_strides
,
1
,
1
,
a_element_op
,
b0_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
);
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
}};
// clang-format off
str
<<
"DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
LPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
LTilePerBlock
<<
", "
<<
L1
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
"ASpec"
<<
getTensorSpecializationString
(
ASpec
)
<<
", "
<<
"B0Spec"
<<
getTensorSpecializationString
(
B0Spec
)
<<
", "
<<
"B1Spec"
<<
getTensorSpecializationString
(
B1Spec
)
<<
", "
<<
"CSpec"
<<
getTensorSpecializationString
(
CSpec
)
<<
", "
<<
getMaskingSpecializationString
(
MaskingSpec
)
<<
">"
<<
" AEnableLds: "
<<
AEnableLds
<<
", "
<<
"B0EnableLds: "
<<
B0EnableLds
<<
", "
<<
"B1EnableLds: "
<<
B1EnableLds
<<
", "
<<
"NumPrefetch: "
<<
NumPrefetch
<<
", "
<<
"LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp
View file @
9de63596
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -322,6 +322,19 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
{
return
std
::
make_unique
<
Invoker
>
();
};
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceElementwiseNormalizationImpl<"
;
str
<<
NumDim
<<
", "
;
str
<<
MPerThread
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
0 → 100644
View file @
9de63596
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// 1. DequantB(K, N) = int2fp(B(K, N)) * scale(1, N)
// 2. C(M, N) = A(M, K) * DequantB(K, N)
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
ScaleDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
GemmSpecialization
GemmSpec
,
ck
::
index_t
NumPrefetch
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
MPerWmma
,
ck
::
index_t
NPerWmma
,
ck
::
index_t
MRepeat
,
ck
::
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsAddExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
ck
::
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
weight_only
>
struct
DeviceFpAintBGemm_Wmma_CShuffle
:
public
DeviceGemm_dequantB
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
// K1 = Max Vector Access Pixels
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
static
constexpr
auto
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
static
constexpr
auto
WmmaK
=
K1
==
16
?
32
:
16
;
static
constexpr
auto
AEnableLds_auto
=
(
NWaves
==
1
&&
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
?
false
:
true
;
static
constexpr
auto
BEnableLds_auto
=
(
MWaves
==
1
&&
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
?
false
:
true
;
// If true, LDS is used unconditionally
// LDS bypass feature not implemented for dequantization pipeline.
static
constexpr
auto
AEnableLds_manu
=
true
;
static
constexpr
auto
BEnableLds_manu
=
true
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
using
DeviceOp
=
DeviceFpAintBGemm_Wmma_CShuffle
;
// Describe how data read from Global memory
static
auto
MakeAGridDescriptor
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
const
auto
a_grid_desc_mraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
StrideA
,
I1
));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
{
const
auto
a_grid_desc_mraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
I1
,
StrideA
));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
}();
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
assert
(
K
%
K1
==
0
);
if
constexpr
(
AEnableLds
)
{
const
index_t
K0
=
K
/
K1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
constexpr
auto
A_KRow
=
2
;
constexpr
auto
A_K0PerWmma
=
WmmaK
/
A_KRow
/
K1Number
;
const
auto
A_KWmma
=
K
/
WmmaK
;
const
auto
M0
=
M
/
MPerBlock
;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
A_KWmma
,
Number
<
A_K0PerWmma
>
{},
Number
<
A_KRow
>
{},
K1Number
)),
make_unmerge_transform
(
make_tuple
(
M0
*
MRepeat
,
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
3
,
4
,
6
>
{},
Sequence
<
1
,
2
,
5
>
{}));
}
}
static
auto
MakeBGridDescriptor
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
const
auto
b_grid_desc_n_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
const
auto
b_grid_desc_nraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
I1
,
StrideB
));
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
const
auto
b_grid_desc_nraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
StrideB
,
I1
));
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
}();
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
assert
(
K
%
K1
==
0
);
if
constexpr
(
BEnableLds
)
{
const
index_t
K0
=
K
/
K1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
constexpr
auto
B_KRow
=
2
;
constexpr
auto
B_K0PerWmma
=
WmmaK
/
B_KRow
/
K1Number
;
const
auto
B_KWmma
=
K
/
WmmaK
;
const
auto
N0
=
N
/
NPerBlock
;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B_KWmma
,
Number
<
B_K0PerWmma
>
{},
Number
<
B_KRow
>
{},
K1Number
)),
make_unmerge_transform
(
make_tuple
(
N0
*
NRepeat
,
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
3
,
4
,
6
>
{},
Sequence
<
1
,
2
,
5
>
{}));
}
}
static
auto
MakeScaleGridDescriptor
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
=
0
)
{
// assume Scale is [1, N]
const
auto
scale_grid_desc_n_k
=
[
&
]()
{
const
auto
scale_grid_desc_nraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
I1
,
StrideB
));
return
matrix_padder
.
PadBDescriptor_N_K
(
scale_grid_desc_nraw_kraw
);
}();
const
auto
N
=
scale_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
scale_grid_desc_n_k
.
GetLength
(
I1
);
// When K = 1, it might be scale tensor.
assert
(
K
%
K1
==
0
&&
K
!=
1
);
if
constexpr
(
BEnableLds
)
{
const
index_t
K0
=
K
/
K1
;
return
transform_tensor_descriptor
(
scale_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
1
)),
// Reduce K1 = 1
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
constexpr
auto
B_KRow
=
2
;
constexpr
auto
B_K0PerWmma
=
WmmaK
/
B_KRow
/
K1Number
;
const
auto
B_KWmma
=
K
/
WmmaK
;
const
auto
N0
=
N
/
NPerBlock
;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return
transform_tensor_descriptor
(
scale_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B_KWmma
,
Number
<
B_K0PerWmma
>
{},
Number
<
B_KRow
>
{},
K1Number
)),
make_unmerge_transform
(
make_tuple
(
N0
*
NRepeat
,
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
3
,
4
,
6
>
{},
Sequence
<
1
,
2
,
5
>
{}));
}
}
static
auto
MakeCGridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
StrideC
)
{
const
auto
c_grid_desc_mraw_nraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
make_tuple
(
StrideC
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
make_tuple
(
I1
,
StrideC
));
}
}();
return
matrix_padder
.
PadCDescriptor_M_N
(
c_grid_desc_mraw_nraw
);
}
// Gridwise descriptor, mapping to whole given provblem.
using
AGridDesc
=
decltype
(
MakeAGridDescriptor
(
1
,
1
,
1
));
using
BGridDesc
=
decltype
(
MakeBGridDescriptor
(
1
,
1
,
1
));
using
ScaleGridDesc
=
decltype
(
MakeScaleGridDescriptor
(
1
,
1
,
0
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseFpAintBGemm_Wmma
<
BlockSize
,
ADataType
,
BDataType
,
ScaleDataType
,
AccDataType
,
CShuffleDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc
,
BGridDesc
,
ScaleGridDesc
,
CGridDesc_M_N
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerWmma
,
NPerWmma
,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
AEnableLds
,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BEnableLds
,
BBlockLdsAddExtraN
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
NumPrefetch
,
LoopSched
,
PipelineVer
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
const
ScaleDataType
*
p_scale_grid
,
CDataType
*
p_c_grid
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
index_t
M01
,
index_t
N01
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_scale_grid_
{
p_scale_grid
},
p_c_grid_
{
p_c_grid
},
a_grid_desc_
{},
b_grid_desc_
{},
scale_grid_desc_
{},
c_grid_desc_m_n_
{},
c_grid_desc_mblock_mperblock_nblock_nperblock
{},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
},
MRaw_
{
M
},
NRaw_
{
N
},
KRaw_
{
K
}
{
a_grid_desc_
=
DeviceOp
::
MakeAGridDescriptor
(
M
,
K
,
StrideA
);
b_grid_desc_
=
DeviceOp
::
MakeBGridDescriptor
(
K
,
N
,
StrideB
);
scale_grid_desc_
=
DeviceOp
::
MakeScaleGridDescriptor
(
K
,
N
,
0
);
c_grid_desc_m_n_
=
DeviceOp
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_
,
b_grid_desc_
,
c_grid_desc_m_n_
,
block_2_ctile_map_
))
{
c_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
}
}
// private:
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
const
ScaleDataType
*
p_scale_grid_
;
CDataType
*
p_c_grid_
;
AGridDesc
a_grid_desc_
;
BGridDesc
b_grid_desc_
;
ScaleGridDesc
scale_grid_desc_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
// for checking vector load/store
index_t
MRaw_
;
index_t
NRaw_
;
index_t
KRaw_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
auto
K
=
[
&
]()
{
if
constexpr
(
AEnableLds
)
{
return
arg
.
a_grid_desc_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_
.
GetLength
(
I2
);
}
else
{
return
arg
.
a_grid_desc_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_
.
GetLength
(
I3
)
*
arg
.
a_grid_desc_
.
GetLength
(
I4
)
*
arg
.
a_grid_desc_
.
GetLength
(
I6
);
}
}();
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
const
auto
kernel
=
kernel_fpAintB_gemm_wmma
<
GridwiseGemm
,
ADataType
,
BDataType
,
ScaleDataType
,
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc
>
,
remove_reference_t
<
DeviceOp
::
ScaleGridDesc
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
has_main_k_block_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_scale_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
scale_grid_desc_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
};
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
ck
::
is_navi3_supported
())
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
ck
::
half_t
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
printf
(
"DeviceOp err: AccDataType"
);
return
false
;
}
}
else
{
printf
(
"DeviceOp err: Arch"
);
return
false
;
}
// check vector load/store
{
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
// check vector load of A
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
ABlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
ABlockTransferSrcVectorDim
==
1
)
{
// FIXME: not rigorous
if
(
arg
.
MRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// check vector laod of B
if
constexpr
(
is_same_v
<
BLayout
,
Col
>
&&
BBlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
BLayout
,
Row
>
&&
BBlockTransferSrcVectorDim
==
1
)
{
// FIXME: not rigorous
if
(
arg
.
NRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// check vector store of C
// only support RowMajor for now
if
constexpr
(
is_same_v
<
CLayout
,
Row
>
)
{
if
(
arg
.
NRaw_
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
);
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
const
ScaleDataType
*
p_scale
,
CDataType
*
p_c
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_scale
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_scale
,
void
*
p_c
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
ScaleDataType
*>
(
p_scale
),
static_cast
<
CDataType
*>
(
p_c
),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
);
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{
{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
},
{
PipelineVersion
::
weight_only
,
"weight_only"
}};
// clang-format off
str
<<
"DeviceFpAintBGemm_Wmma_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
K1
<<
", "
<<
MPerWmma
<<
", "
<<
NPerWmma
<<
", "
<<
MRepeat
<<
", "
<<
NRepeat
<<
">"
<<
" AEnableLds: "
<<
AEnableLds
<<
", "
<<
"BEnableLds: "
<<
BEnableLds
<<
", "
<<
"NumPrefetch: "
<<
NumPrefetch
<<
", "
<<
"LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
View file @
9de63596
...
...
@@ -16,6 +16,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -27,21 +28,22 @@ template <typename ALayout,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
GemmSpecialization
GemmSpec
,
ck
::
index_t
NumPrefetch
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K
0
PerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
MPerW
MMA
,
ck
::
index_t
NPerW
MMA
,
ck
::
index_t
MPerW
mma
,
ck
::
index_t
NPerW
mma
,
ck
::
index_t
MRepeat
,
ck
::
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
...
...
@@ -62,7 +64,6 @@ template <typename ALayout,
index_t
CShuffleNRepeatPerShuffle
,
typename
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
ck
::
index_t
NumPrefetch
=
1
,
ck
::
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
v1
>
struct
DeviceGemmMultipleD_Wmma_CShuffle
:
public
DeviceGemmMultipleD
<
ALayout
,
...
...
@@ -83,68 +84,139 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
// K1 = Max Vector Access Pixels
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
static
constexpr
auto
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
static
constexpr
auto
WmmaK
=
K1
==
16
?
32
:
16
;
static
constexpr
auto
AEnableLds_auto
=
(
NWaves
==
1
&&
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
?
false
:
true
;
static
constexpr
auto
BEnableLds_auto
=
(
MWaves
==
1
&&
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
?
false
:
true
;
// If true, LDS is used unconditionally
static
constexpr
auto
AEnableLds_manu
=
false
;
static
constexpr
auto
BEnableLds_manu
=
false
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
K
0
PerBlock
*
K1
};
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
// Describe how data read from Global memory
static
auto
MakeAGridDescriptor
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
const
auto
a_grid_desc_m
raw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
const
auto
a_grid_desc_m
_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
StrideA
,
I1
));
const
auto
a_grid_desc_mraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
StrideA
,
I1
));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
else
if
constexpr
(
is_same
_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
I1
,
StrideA
));
const
auto
a_grid_desc_mraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
I1
,
StrideA
));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
}();
const
auto
a_grid_desc_m_k
=
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
if
constexpr
(
AEnableLds
)
{
const
index_t
K0
=
K
/
K1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
constexpr
auto
A_KRow
=
2
;
constexpr
auto
A_K0PerWmma
=
WmmaK
/
A_KRow
/
K1Number
;
const
auto
A_KWmma
=
K
/
WmmaK
;
const
auto
M0
=
M
/
MPerBlock
;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
A_KWmma
,
Number
<
A_K0PerWmma
>
{},
Number
<
A_KRow
>
{},
K1Number
)),
make_unmerge_transform
(
make_tuple
(
M0
*
MRepeat
,
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
3
,
4
,
6
>
{},
Sequence
<
1
,
2
,
5
>
{}));
}
}
static
auto
MakeBGridDescriptor
_K0_N_K1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
static
auto
MakeBGridDescriptor
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
const
auto
b_grid_desc_n
raw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
)
const
auto
b_grid_desc_n
_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
I1
,
StrideB
));
const
auto
b_grid_desc_nraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
I1
,
StrideB
));
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
StrideB
,
I1
));
const
auto
b_grid_desc_nraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
StrideB
,
I1
));
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
}();
const
auto
b_grid_desc_n_k
=
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
if
constexpr
(
BEnableLds
)
{
const
index_t
K0
=
K
/
K1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
constexpr
auto
B_KRow
=
2
;
constexpr
auto
B_K0PerWmma
=
WmmaK
/
B_KRow
/
K1Number
;
const
auto
B_KWmma
=
K
/
WmmaK
;
const
auto
N0
=
N
/
NPerBlock
;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B_KWmma
,
Number
<
B_K0PerWmma
>
{},
Number
<
B_KRow
>
{},
K1Number
)),
make_unmerge_transform
(
make_tuple
(
N0
*
NRepeat
,
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
3
,
4
,
6
>
{},
Sequence
<
1
,
2
,
5
>
{}));
}
}
template
<
typename
ELayout_
>
...
...
@@ -180,13 +252,13 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
}
// Gridwise descriptor, mapping to whole given provblem.
using
AGridDesc
_K0_M_K1
=
decltype
(
MakeAGridDescriptor
_K0_M_K1
(
1
,
1
,
1
));
using
BGridDesc
_K0_N_K1
=
decltype
(
MakeBGridDescriptor
_K0_N_K1
(
1
,
1
,
1
));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
AGridDesc
=
decltype
(
MakeAGridDescriptor
(
1
,
1
,
1
));
using
BGridDesc
=
decltype
(
MakeBGridDescriptor
(
1
,
1
,
1
));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
// GridwiseOp
using
GridwiseOp
=
GridwiseGemmMultipleD_
k0mk1_k0nk1_mn_wmma_cshuffle
<
using
GridwiseOp
=
GridwiseGemmMultipleD_
Wmma
<
// DataType Family
ADataType
,
BDataType
,
...
...
@@ -195,8 +267,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
DsDataType
,
EDataType
,
// InMemory Data Descriptor
AGridDesc
_K0_M_K1
,
BGridDesc
_K0_N_K1
,
AGridDesc
,
BGridDesc
,
DsGridDesc_M_N
,
EGridDesc_M_N
,
// ElementwiseOp Family
...
...
@@ -207,9 +279,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
// Tiling Family
MPerBlock
,
NPerBlock
,
K
0
PerBlock
,
MPerW
MMA
,
NPerW
MMA
,
KPerBlock
,
MPerW
mma
,
NPerW
mma
,
K1
,
MRepeat
,
NRepeat
,
...
...
@@ -222,6 +294,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
AEnableLds
,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
...
...
@@ -230,6 +303,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BEnableLds
,
BBlockLdsAddExtraN
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
...
...
@@ -262,8 +336,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
a_grid_desc
_k0_m_k1_
{},
b_grid_desc
_k0_n_k1_
{},
a_grid_desc
{},
b_grid_desc
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock
{},
...
...
@@ -278,8 +352,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
NRaw_
{
N
},
KRaw_
{
K
}
{
a_grid_desc
_k0_m_k1_
=
DeviceOp
::
MakeAGridDescriptor
_K0_M_K1
(
M
,
K
,
StrideA
);
b_grid_desc
_k0_n_k1_
=
DeviceOp
::
MakeBGridDescriptor
_K0_N_K1
(
K
,
N
,
StrideB
);
a_grid_desc
=
DeviceOp
::
MakeAGridDescriptor
(
M
,
K
,
StrideA
);
b_grid_desc
=
DeviceOp
::
MakeBGridDescriptor
(
K
,
N
,
StrideB
);
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
...
...
@@ -295,8 +369,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
block_2_ctile_map_
=
GridwiseOp
::
MakeDefaultBlock2CTileMap
(
e_grid_desc_m_n_
,
M01
,
N01
);
if
(
GridwiseOp
::
CheckValidity
(
a_grid_desc
_k0_m_k1_
,
b_grid_desc
_k0_n_k1_
,
if
(
GridwiseOp
::
CheckValidity
(
a_grid_desc
,
b_grid_desc
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
block_2_ctile_map_
))
...
...
@@ -318,8 +392,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
EDataType
*
p_e_grid_
;
// Tensor Descriptors
AGridDesc
_K0_M_K1
a_grid_desc
_k0_m_k1_
;
BGridDesc
_K0_N_K1
b_grid_desc
_k0_n_k1_
;
AGridDesc
a_grid_desc
;
BGridDesc
b_grid_desc
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
...
@@ -352,24 +426,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if 0
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
}
#endif
if
(
!
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
if
(
!
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc
,
arg
.
b_grid_desc
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
))
...
...
@@ -381,21 +439,27 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
const
auto
K
=
[
&
]()
{
if
constexpr
(
AEnableLds
)
{
return
arg
.
a_grid_desc
.
GetLength
(
I0
)
*
arg
.
a_grid_desc
.
GetLength
(
I2
);
}
else
{
return
arg
.
a_grid_desc
.
GetLength
(
I0
)
*
arg
.
a_grid_desc
.
GetLength
(
I3
)
*
arg
.
a_grid_desc
.
GetLength
(
I4
)
*
arg
.
a_grid_desc
.
GetLength
(
I6
);
}
}();
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
const
auto
kernel
=
kernel_gemm_mupltipe_d_wmma_cshuffle
<
GridwiseOp
,
ADataType
,
BDataType
,
typename
GridwiseOp
::
DsGridPointer
,
EDataType
,
remove_reference_t
<
typename
DeviceOp
::
AGridDesc
_K0_M_K1
>
,
remove_reference_t
<
typename
DeviceOp
::
BGridDesc
_K0_N_K1
>
,
remove_reference_t
<
typename
DeviceOp
::
AGridDesc
>
,
remove_reference_t
<
typename
DeviceOp
::
BGridDesc
>
,
remove_reference_t
<
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
...
...
@@ -404,68 +468,35 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
BElementwiseOperation
,
CDEElementwiseOperation
,
remove_reference_t
<
typename
GridwiseOp
::
DefaultBlock2CTileMap
>
,
true
>
;
// Last Option is W/O
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
block_2_ctile_map_
);
has_main_k_block_loop
>
;
// Last Option is W/O
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_grid_desc
,
arg
.
b_grid_desc
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
block_2_ctile_map_
);
};
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
const
auto
kernel
=
kernel_gemm_mupltipe_d_wmma_cshuffle
<
GridwiseOp
,
ADataType
,
BDataType
,
typename
GridwiseOp
::
DsGridPointer
,
EDataType
,
remove_reference_t
<
typename
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
typename
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
remove_reference_t
<
typename
GridwiseOp
::
DefaultBlock2CTileMap
>
,
false
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
block_2_ctile_map_
);
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
return
ave_time
;
}
// polymorphic
...
...
@@ -575,8 +606,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}
return
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc
_k0_m_k1_
,
arg
.
b_grid_desc
_k0_n_k1_
,
return
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc
,
arg
.
b_grid_desc
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
);
...
...
@@ -681,14 +712,18 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K
0
PerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
K1
<<
", "
<<
MPerW
MMA
<<
", "
<<
NPerW
MMA
<<
", "
<<
MPerW
mma
<<
", "
<<
NPerW
mma
<<
", "
<<
MRepeat
<<
", "
<<
NRepeat
<<
">"
<<
" NumPrefetch: "
<<
" AEnableLds: "
<<
AEnableLds
<<
", "
<<
"BEnableLds: "
<<
BEnableLds
<<
", "
<<
"NumPrefetch: "
<<
NumPrefetch
<<
", "
<<
"LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
9de63596
...
...
@@ -498,94 +498,95 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
}
};
static
bool
IsSupported
Argument
(
const
Argument
&
arg
)
static
constexpr
bool
IsSupported
(
index_t
MRaw_
,
index_t
NRaw_
,
index_t
KRaw_
)
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
// check vector load/store
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
// check vector load of A
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
ABlockTransferSrcVectorDim
==
2
)
{
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
// check vector load of A
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
ABlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
ABlockTransferSrcVectorDim
==
1
)
{
// FIXME: not rigorous
if
(
arg
.
MRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
(
KRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
// check vector laod of B
if
constexpr
(
is_same_v
<
BLayout
,
Col
>
&&
BBlockTransferSrcVectorDim
==
2
)
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
ABlockTransferSrcVectorDim
==
1
)
{
// FIXME: not rigorous
if
(
MRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
if
(
arg
.
KRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
return
false
;
}
else
if
constexpr
(
is_same_v
<
BLayout
,
Row
>
&&
BBlockTransferSrcVectorDim
==
1
)
}
else
{
return
false
;
}
// check vector laod of B
if
constexpr
(
is_same_v
<
BLayout
,
Col
>
&&
BBlockTransferSrcVectorDim
==
2
)
{
if
(
KRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
// FIXME: not rigorous
if
(
arg
.
NRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
return
false
;
}
else
}
else
if
constexpr
(
is_same_v
<
BLayout
,
Row
>
&&
BBlockTransferSrcVectorDim
==
1
)
{
// FIXME: not rigorous
if
(
NRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// check vector load of Ds
// only support RowMajor for now
bool
all_valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
// check vector load of Ds
// only support RowMajor for now
bool
all_valid
=
true
;
if
constexpr
(
!
is_same_v
<
DLayout
,
Row
>
)
{
all_valid
=
false
;
}
});
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
if
(
!
all_valid
)
if
constexpr
(
!
is_same_v
<
DLayout
,
Row
>
)
{
return
false
;
all_valid
=
false
;
}
});
// check vector store of E
// only support RowMajor for now
if
constexpr
(
is_same_v
<
ELayout
,
Row
>
)
{
if
(
arg
.
NRaw_
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
else
if
(
!
all_valid
)
{
return
false
;
}
// check vector store of E
// only support RowMajor for now
if
constexpr
(
is_same_v
<
ELayout
,
Row
>
)
{
if
(
NRaw_
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
return
true
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
return
IsSupported
(
arg
.
MRaw_
,
arg
.
NRaw_
,
arg
.
KRaw_
)
and
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
...
...
@@ -708,6 +709,178 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return
str
.
str
();
}
template
<
class
ADesc
,
class
BDesc
,
class
DsDesc
,
class
EDesc
>
struct
Descriptor
{
static
constexpr
auto
ds_tuple
()
{
return
transform_tuples
(
[
&
](
auto
d
)
constexpr
{
return
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
d
);
},
DsDesc
{});
}
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
DeviceOp
::
matrix_padder
.
PadADescriptor_M_K
(
ADesc
{}))
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
DeviceOp
::
matrix_padder
.
PadBDescriptor_N_K
(
BDesc
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ds_tuple
())
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
EDesc
{}))
>
;
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
DeviceOp
::
matrix_padder
.
PadADescriptor_M_K
(
ADesc
{})))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
DeviceOp
::
matrix_padder
.
PadBDescriptor_N_K
(
BDesc
{})))
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_tuple
()))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
EDesc
{})))
>
;
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
EDesc
{})))
>
;
// tensor descriptors for problem definiton
AGridDesc_M_K
a_grid_desc_m_k
;
BGridDesc_N_K
b_grid_desc_n_k
;
DsGridDesc_M_N
ds_grid_desc_m_n
;
EGridDesc_M_N
e_grid_desc_m_n
;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
;
// block-to-e-tile map
Block2ETileMap
block_2_etile_map
;
// element-wise op
AElementwiseOperation
a_element_op
;
BElementwiseOperation
b_element_op
;
CDEElementwiseOperation
cde_element_op
;
// for checking vector load/store
index_t
MRaw
;
index_t
NRaw
;
index_t
KRaw
;
bool
has_main_k_block_loop
=
true
;
constexpr
Descriptor
(
ADesc
a
,
BDesc
b
,
DsDesc
ds
,
EDesc
e
,
AElementwiseOperation
a_element_op_
,
BElementwiseOperation
b_element_op_
,
CDEElementwiseOperation
cde_element_op_
)
:
a_grid_desc_m_k
{
DeviceOp
::
matrix_padder
.
PadADescriptor_M_K
(
a
)},
b_grid_desc_n_k
{
DeviceOp
::
matrix_padder
.
PadBDescriptor_N_K
(
b
)},
ds_grid_desc_m_n
{
transform_tuples
(
[
&
](
auto
d
)
constexpr
{
return
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
d
);
},
ds
)},
e_grid_desc_m_n
{
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
e
)},
a_grid_desc_ak0_m_ak1
{
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k
)},
b_grid_desc_bk0_n_bk1
{
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k
)},
ds_grid_desc_mblock_mperblock_nblock_nperblock
{
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
transform_tuples
(
[
&
](
auto
d
)
constexpr
{
return
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
d
);
},
ds
))},
e_grid_desc_mblock_mperblock_nblock_nperblock
{
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n
)},
block_2_etile_map
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n
)},
has_main_k_block_loop
{
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))},
a_element_op
{
a_element_op_
},
b_element_op
{
b_element_op_
},
cde_element_op
{
cde_element_op_
},
MRaw
{
e
.
GetLength
(
I0
)},
NRaw
{
e
.
GetLength
(
I1
)},
KRaw
{
a
.
GetLength
(
I1
)}
{
}
constexpr
bool
IsValid
()
const
{
return
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k
,
b_grid_desc_n_k
,
ds_grid_desc_m_n
,
e_grid_desc_m_n
,
block_2_etile_map
)
and
IsSupported
(
MRaw
,
NRaw
,
KRaw
);
}
constexpr
index_t
GetBlockSize
()
const
{
return
BlockSize
;
}
constexpr
index_t
GetGridSize
()
const
{
return
block_2_etile_map
.
CalculateGridSize
(
e_grid_desc_m_n
);
}
};
template
<
class
ADesc
,
class
BDesc
,
class
DsDesc
,
class
EDesc
>
static
constexpr
auto
make_descriptor
(
ADesc
a
,
BDesc
b
,
DsDesc
ds
,
EDesc
e
,
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{},
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{},
CDEElementwiseOperation
cde_element_op
=
CDEElementwiseOperation
{})
{
return
Descriptor
<
ADesc
,
BDesc
,
DsDesc
,
EDesc
>
(
a
,
b
,
ds
,
e
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
template
<
class
Desc
,
class
DsPointer
>
__device__
static
void
Run
(
const
Desc
&
desc
,
const
ADataType
*
__restrict__
p_a_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
DsPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
)
{
__shared__
char
p_shared_block
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
assert
(
desc
.
IsValid
());
if
(
desc
.
has_main_k_block_loop
)
{
GridwiseGemm
::
template
Run
<
true
>(
p_a_grid
,
p_b_grid
,
p_ds_grid
,
p_e_grid
,
p_shared_block
,
desc
.
a_element_op
,
desc
.
b_element_op
,
desc
.
cde_element_op
,
desc
.
a_grid_desc_ak0_m_ak1
,
desc
.
b_grid_desc_bk0_n_bk1
,
desc
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
desc
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
desc
.
block_2_etile_map
);
}
else
{
GridwiseGemm
::
template
Run
<
false
>(
p_a_grid
,
p_b_grid
,
p_ds_grid
,
p_e_grid
,
p_shared_block
,
desc
.
a_element_op
,
desc
.
b_element_op
,
desc
.
cde_element_op
,
desc
.
a_grid_desc_ak0_m_ak1
,
desc
.
b_grid_desc_bk0_n_bk1
,
desc
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
desc
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
desc
.
block_2_etile_map
);
}
}
};
}
// namespace device
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment