Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
LLama_fastertransformer
Commits
068fb458
Commit
068fb458
authored
Aug 25, 2023
by
liuhy
Browse files
修改ck代码适配gfx926
parent
acd8b8ea
Changes
27
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1082 additions
and
1018 deletions
+1082
-1018
3rdparty/composable_kernel/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp
.../12_elementwise_normalization/elementwise_layernorm2d.cpp
+175
-175
3rdparty/composable_kernel/client_example/README.md
3rdparty/composable_kernel/client_example/README.md
+4
-1
3rdparty/composable_kernel/cmake/googletest.cmake
3rdparty/composable_kernel/cmake/googletest.cmake
+0
-1
3rdparty/composable_kernel/compile.sh
3rdparty/composable_kernel/compile.sh
+1
-1
3rdparty/composable_kernel/example/36_sparse_embedding/CMakeLists.txt
...posable_kernel/example/36_sparse_embedding/CMakeLists.txt
+1
-1
3rdparty/composable_kernel/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp
...entwise_normalization/elementwise_layernorm_blockwise.cpp
+195
-195
3rdparty/composable_kernel/include/ck/ck.hpp
3rdparty/composable_kernel/include/ck/ck.hpp
+2
-2
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
.../ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
+4
-4
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp
...operation/gpu/device/device_elementwise_normalization.hpp
+68
-68
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
.../device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
+2
-2
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
.../gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
+2
-2
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
...gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
+1
-1
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
...de/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
+6
-1
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp
.../grid/gridwise_elementwise_layernorm_welford_variance.hpp
+500
-500
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp
...tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp
+1
-8
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
...de/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
+80
-47
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp
+1
-1
3rdparty/composable_kernel/include/ck/utility/data_type.hpp
3rdparty/composable_kernel/include/ck/utility/data_type.hpp
+11
-0
3rdparty/composable_kernel/include/ck/utility/inner_product.hpp
...ty/composable_kernel/include/ck/utility/inner_product.hpp
+25
-6
3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp
...gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp
+3
-2
No files found.
3rdparty/composable_kernel/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp
View file @
068fb458
3rdparty/composable_kernel/client_example/README.md
View file @
068fb458
...
@@ -9,7 +9,10 @@ cd client_example/build
...
@@ -9,7 +9,10 @@ cd client_example/build
```
```
```
bash
```
bash
cmake
-D
CMAKE_CXX_COMPILER
=
${
ROCM_PATH
}
/bin/hipcc
-D
CMAKE_PREFIX_PATH
=
"
${
ROCM_PATH
}
;
${
PATH_TO_CK_INSTALL_DIRECTORY
}
"
..
cmake
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_PREFIX_PATH
=
"/opt/rocm;
${
PATH_TO_CK_INSTALL_DIRECTORY
}
"
\
..
```
```
### Build client example
### Build client example
...
...
3rdparty/composable_kernel/cmake/googletest.cmake
View file @
068fb458
...
@@ -27,7 +27,6 @@ message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLA
...
@@ -27,7 +27,6 @@ message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLA
FetchContent_Declare
(
FetchContent_Declare
(
googletest
googletest
GIT_REPOSITORY http://10.0.50.24/Mirrors/googletest.git
GIT_REPOSITORY http://10.0.50.24/Mirrors/googletest.git
# GIT_REPOSITORY /work/home/zhangshao/installer/googletest
GIT_TAG b85864c64758dec007208e56af933fc3f52044ee
GIT_TAG b85864c64758dec007208e56af933fc3f52044ee
)
)
...
...
3rdparty/composable_kernel/compile.sh
View file @
068fb458
...
@@ -7,7 +7,7 @@ cmake
...
@@ -7,7 +7,7 @@ cmake
-D
CMAKE_CXX_FLAGS
=
"-O3"
\
-D
CMAKE_CXX_FLAGS
=
"-O3"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
GPU_TARGETS
=
"gfx906;gfx926"
\
-D
GPU_TARGETS
=
"gfx906;gfx926"
\
-D
CMAKE_INSTALL_PREFIX
=
~/composable_kernel/install_ck
\
-D
CMAKE_INSTALL_PREFIX
=
~/composable_kernel
-develop
/install_ck
\
..
..
cd
-
cd
-
3rdparty/composable_kernel/example/36_sparse_embedding/CMakeLists.txt
View file @
068fb458
3rdparty/composable_kernel/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp
View file @
068fb458
3rdparty/composable_kernel/include/ck/ck.hpp
View file @
068fb458
...
@@ -33,7 +33,7 @@
...
@@ -33,7 +33,7 @@
// buffer resource
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)||defined(__gfx926__)
|| defined(__gfx908__) || \
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
||
defined(__gfx926__) || defined(__gfx908__) || \
defined(__gfx90a__) // for GPU code
defined(__gfx90a__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx1030__) // for GPU code
#elif defined(__gfx1030__) // for GPU code
...
@@ -46,7 +46,7 @@
...
@@ -46,7 +46,7 @@
#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
#define CK_USE_AMD_V_MAC_F32
#define CK_USE_AMD_V_MAC_F32
#elif defined(__gfx906__)|| defined(__gfx926__) || defined(__gfx908__) || defined(__gfx90a__) || \
#elif defined(__gfx906__)
|| defined(__gfx926__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx1030__) // for GPU code
defined(__gfx1030__) // for GPU code
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT2_F32_F16
...
...
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
View file @
068fb458
...
@@ -225,13 +225,13 @@ struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_
...
@@ -225,13 +225,13 @@ struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
>
(
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
>
(
a_thread_desc_bk0_bm0_bm1_bk1_
.
GetElementSpaceSize
());
a_thread_desc_bk0_bm0_bm1_bk1_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
B
>
(
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
A
>
(
b_thread_desc_bk0_bn0_bn1_bk1_
.
GetElementSpaceSize
());
b_thread_desc_bk0_bn0_bn1_bk1_
.
GetElementSpaceSize
());
constexpr
auto
threadwise_contraction
=
constexpr
auto
threadwise_contraction
=
ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
<
ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
<
FloatA
,
FloatA
,
Float
B
,
Float
A
,
FloatC
,
FloatC
,
decltype
(
a_thread_desc_bk0_bm0_bm1_bk1_
),
decltype
(
a_thread_desc_bk0_bm0_bm1_bk1_
),
decltype
(
b_thread_desc_bk0_bn0_bn1_bk1_
),
decltype
(
b_thread_desc_bk0_bn0_bn1_bk1_
),
...
@@ -394,8 +394,8 @@ struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_
...
@@ -394,8 +394,8 @@ struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_
Sequence
<
0
,
1
,
2
,
3
>>
;
// SrcVectorTensorContiguousDimOrder
Sequence
<
0
,
1
,
2
,
3
>>
;
// SrcVectorTensorContiguousDimOrder
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4r1
<
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4r1
<
FloatB
,
FloatB
,
// src
Float
B
,
Float
A
,
// dst
decltype
(
b_block_desc_bk0_bn0_bn1_bk1_
),
decltype
(
b_block_desc_bk0_bn0_bn1_bk1_
),
decltype
(
b_thread_desc_bk0_bn0_bn1_bk1_
),
decltype
(
b_thread_desc_bk0_bn0_bn1_bk1_
),
Sequence
<
BK0PerThread
,
1
,
BN1PerThreadBN11
,
BK1
>
,
// SliceLengths
Sequence
<
BK0PerThread
,
1
,
BN1PerThreadBN11
,
BK1
>
,
// SliceLengths
...
...
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp
View file @
068fb458
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
View file @
068fb458
...
@@ -134,7 +134,7 @@ __global__ void
...
@@ -134,7 +134,7 @@ __global__ void
const
Block2CTileMap
block_2_ctile_map
,
const
Block2CTileMap
block_2_ctile_map
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__)||defined(__gfx926__)
|| defined(__gfx1030__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__)
||
defined(__gfx926__) || defined(__gfx1030__))
// offset base pointer for each work-group
// offset base pointer for each work-group
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
@@ -709,7 +709,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
...
@@ -709,7 +709,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
namespace
ctc
=
tensor_layout
::
convolution
;
namespace
ctc
=
tensor_layout
::
convolution
;
// check device
// check device
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
get_device_name
()
==
"gfx926"
||
ck
::
get_device_name
()
==
"gfx1030"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
get_device_name
()
==
"gfx1030"
))
{
{
return
false
;
return
false
;
}
}
...
...
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
View file @
068fb458
...
@@ -106,7 +106,7 @@ __global__ void
...
@@ -106,7 +106,7 @@ __global__ void
const
Block2CTileMap
block_2_ctile_map
,
const
Block2CTileMap
block_2_ctile_map
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__)|| defined(__gfx926__) || defined(__gfx1030__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__)
|| defined(__gfx926__) || defined(__gfx1030__))
// offset base pointer for each work-group
// offset base pointer for each work-group
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
@@ -600,7 +600,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
...
@@ -600,7 +600,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
namespace
ctc
=
tensor_layout
::
convolution
;
namespace
ctc
=
tensor_layout
::
convolution
;
// check device
// check device
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
get_device_name
()
==
"gfx926"
||
ck
::
get_device_name
()
==
"gfx1030"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
get_device_name
()
==
"gfx1030"
))
{
{
return
false
;
return
false
;
}
}
...
...
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
View file @
068fb458
...
@@ -1391,7 +1391,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
...
@@ -1391,7 +1391,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
// check device
// check device
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
get_device_name
()
==
"gfx926"
||
ck
::
get_device_name
()
==
"gfx1030"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
get_device_name
()
==
"gfx926"
||
ck
::
get_device_name
()
==
"gfx1030"
))
{
{
return
false
;
return
false
;
}
}
...
...
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
View file @
068fb458
...
@@ -205,6 +205,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
...
@@ -205,6 +205,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
using
GridwiseGemm
=
using
GridwiseGemm
=
GridwiseGemmDl_km_kn_mn_v1r3
<
BlockSize
,
GridwiseGemmDl_km_kn_mn_v1r3
<
BlockSize
,
ADataType
,
ADataType
,
BDataType
,
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
...
@@ -364,6 +365,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
...
@@ -364,6 +365,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
const
auto
kernel
=
const
auto
kernel
=
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
ADataType
,
ADataType
,
BDataType
,
CDataType
,
CDataType
,
remove_reference_t
<
AGridDesc_K0_M0_M1_K1
>
,
remove_reference_t
<
AGridDesc_K0_M0_M1_K1
>
,
remove_reference_t
<
BGridDesc_K0_N0_N1_K1
>
,
remove_reference_t
<
BGridDesc_K0_N0_N1_K1
>
,
...
@@ -390,6 +392,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
...
@@ -390,6 +392,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
const
auto
kernel
=
const
auto
kernel
=
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
ADataType
,
ADataType
,
BDataType
,
CDataType
,
CDataType
,
remove_reference_t
<
AGridDesc_K0_M0_M1_K1
>
,
remove_reference_t
<
AGridDesc_K0_M0_M1_K1
>
,
remove_reference_t
<
BGridDesc_K0_N0_N1_K1
>
,
remove_reference_t
<
BGridDesc_K0_N0_N1_K1
>
,
...
@@ -416,6 +419,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
...
@@ -416,6 +419,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
const
auto
kernel
=
const
auto
kernel
=
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
ADataType
,
ADataType
,
BDataType
,
CDataType
,
CDataType
,
remove_reference_t
<
AGridDesc_K0_M0_M1_K1
>
,
remove_reference_t
<
AGridDesc_K0_M0_M1_K1
>
,
remove_reference_t
<
BGridDesc_K0_N0_N1_K1
>
,
remove_reference_t
<
BGridDesc_K0_N0_N1_K1
>
,
...
@@ -442,6 +446,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
...
@@ -442,6 +446,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
const
auto
kernel
=
const
auto
kernel
=
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
ADataType
,
ADataType
,
BDataType
,
CDataType
,
CDataType
,
remove_reference_t
<
AGridDesc_K0_M0_M1_K1
>
,
remove_reference_t
<
AGridDesc_K0_M0_M1_K1
>
,
remove_reference_t
<
BGridDesc_K0_N0_N1_K1
>
,
remove_reference_t
<
BGridDesc_K0_N0_N1_K1
>
,
...
@@ -483,7 +488,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
...
@@ -483,7 +488,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
get_device_name
()
==
"gfx926"
||
ck
::
get_device_name
()
==
"gfx1030"
)
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
get_device_name
()
==
"gfx926"
||
ck
::
get_device_name
()
==
"gfx1030"
)
{
{
return
GridwiseGemm
::
CheckValidity
(
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
);
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
);
...
...
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp
View file @
068fb458
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp
View file @
068fb458
...
@@ -117,18 +117,11 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
...
@@ -117,18 +117,11 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
bool
ret
=
(
M
==
c_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
c_grid_desc_m_n
.
GetLength
(
I1
)
&&
return
(
M
==
c_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
c_grid_desc_m_n
.
GetLength
(
I1
)
&&
K0
==
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
&&
K0
==
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
&&
K1
==
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
&&
K1
==
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
&&
K1
==
b_grid_desc_k0_n_k1
.
GetLength
(
I2
))
&&
K1
==
b_grid_desc_k0_n_k1
.
GetLength
(
I2
))
&&
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
);
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
);
if
(
!
ret
){
std
::
cout
<<
"M="
<<
M
<<
" N="
<<
N
<<
" K0="
<<
K0
<<
" c_grid_desc_m_n[0]="
<<
c_grid_desc_m_n
.
GetLength
(
I0
)
<<
" c_grid_desc_m_n[1]="
<<
c_grid_desc_m_n
.
GetLength
(
I1
)
<<
" b_grid_desc_k0_n_k1[0]="
<<
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
<<
" b_grid_desc_k0_n_k1[2]="
<<
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)
<<
" a_grid_desc_k0_m_k1[2]="
<<
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
<<
" K1="
<<
K1
<<
" MPerBlock="
<<
MPerBlock
<<
" NPerBlock="
<<
NPerBlock
<<
" K0PerBlock="
<<
K0PerBlock
<<
std
::
endl
;
}
return
ret
;
}
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
...
...
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
View file @
068fb458
...
@@ -18,7 +18,8 @@
...
@@ -18,7 +18,8 @@
namespace
ck
{
namespace
ck
{
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
FloatC
,
typename
AGridDesc_K0_M0_M1_K1
,
typename
AGridDesc_K0_M0_M1_K1
,
typename
BGridDesc_K0_N0_N1_K1
,
typename
BGridDesc_K0_N0_N1_K1
,
...
@@ -30,23 +31,27 @@ __global__ void
...
@@ -30,23 +31,27 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_gemm_dl_v1r3
(
const
FloatA
B
*
__restrict__
p_a_grid
,
kernel_gemm_dl_v1r3
(
const
FloatA
*
__restrict__
p_a_grid
,
const
Float
A
B
*
__restrict__
p_b_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_K0_M0_M1_K1
a_grid_desc_k0_m0_m1_k1
,
const
AGridDesc_K0_M0_M1_K1
a_grid_desc_k0_m0_m1_k1
,
const
BGridDesc_K0_N0_N1_K1
b_grid_desc_k0_n0_n1_k1
,
const
BGridDesc_K0_N0_N1_K1
b_grid_desc_k0_n0_n1_k1
,
const
CGridDesc_M0_M10_M11_N0_N10_N11
c_grid_desc_m0_m10_m11_n0_n10_n11
,
const
CGridDesc_M0_M10_M11_N0_N10_N11
c_grid_desc_m0_m10_m11_n0_n10_n11
,
const
Block2CTileMap
block_2_ctile_map
)
const
Block2CTileMap
block_2_ctile_map
)
{
{
constexpr
index_t
shared_block_size
=
constexpr
index_t
shared_block_size_of_a
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
GridwiseGemm
::
GetSharedMemoryNumberOfByteToA
()
/
sizeof
(
FloatA
);
constexpr
index_t
shared_block_size_of_b
=
GridwiseGemm
::
GetSharedMemoryNumberOfByteToB
()
/
sizeof
(
FloatB
);
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
__shared__
FloatA
p_shared_block_a
[
shared_block_size_of_a
];
__shared__
FloatB
p_shared_block_b
[
shared_block_size_of_b
];
GridwiseGemm
::
Run
(
p_a_grid
,
GridwiseGemm
::
Run
(
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
p_shared_block
,
p_shared_block_a
,
p_shared_block_b
,
a_grid_desc_k0_m0_m1_k1
,
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
b_grid_desc_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
...
@@ -56,7 +61,8 @@ __global__ void
...
@@ -56,7 +61,8 @@ __global__ void
}
}
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatA
,
typename
FloatB
,
typename
FloatAcc
,
typename
FloatAcc
,
typename
FloatC
,
typename
FloatC
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
...
@@ -99,7 +105,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -99,7 +105,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
// K1 should be Number<...>
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
ToA
()
{
{
// TODO: change this. I think it needs multi-dimensional alignment
// TODO: change this. I think it needs multi-dimensional alignment
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
max_lds_align
=
K1
;
...
@@ -122,7 +128,33 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -122,7 +128,33 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
constexpr
auto
b_block_aligned_space_size
=
constexpr
auto
b_block_aligned_space_size
=
math
::
integer_least_multiple
(
b_block_desc_k_n
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
b_block_desc_k_n
.
GetElementSpaceSize
(),
max_lds_align
);
return
2
*
(
a_block_aligned_space_size
+
b_block_aligned_space_size
)
*
sizeof
(
FloatAB
);
return
2
*
a_block_aligned_space_size
*
sizeof
(
FloatA
);
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByteToB
()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr
auto
max_lds_align
=
K1
;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_k_m
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_k_n
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
// TODO: check alignment
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_aligned_space_size
=
math
::
integer_least_multiple
(
a_block_desc_k_m
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_aligned_space_size
=
math
::
integer_least_multiple
(
b_block_desc_k_n
.
GetElementSpaceSize
(),
max_lds_align
);
return
2
*
b_block_aligned_space_size
*
sizeof
(
FloatB
);
}
}
__host__
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
...
@@ -145,14 +177,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -145,14 +177,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
{
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
// M0 * N0
return
grid_size
;
return
grid_size
;
}
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K0
)
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K0
)
{
{
const
bool
has_main_k_block_loop
=
(
K0
+
K0PerBlock
)
/
(
2
*
K0PerBlock
)
>
1
;
const
bool
has_main_k_block_loop
=
(
K0
+
K0PerBlock
)
/
(
2
*
K0PerBlock
)
>
1
;
// K0 > K0PerBlock ???
return
has_main_k_block_loop
;
return
has_main_k_block_loop
;
}
}
...
@@ -170,7 +202,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -170,7 +202,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
M1
=
Number
<
MPerBlock
>
{};
const
auto
M1
=
Number
<
MPerBlock
>
{};
// 128
const
auto
M0
=
M
/
M1
;
const
auto
M0
=
M
/
M1
;
const
auto
a_grid_desc_k0_m0_m1_k1
=
const
auto
a_grid_desc_k0_m0_m1_k1
=
...
@@ -178,8 +210,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -178,8 +210,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
make_tuple
(
make_pass_through_transform
(
K0
),
make_tuple
(
make_pass_through_transform
(
K0
),
make_unmerge_transform
(
make_tuple
(
M0
,
M1
)),
make_unmerge_transform
(
make_tuple
(
M0
,
M1
)),
make_pass_through_transform
(
K1
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
// K0, M, K1
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
// K0, M0, M1, K1
return
a_grid_desc_k0_m0_m1_k1
;
return
a_grid_desc_k0_m0_m1_k1
;
}
}
...
@@ -190,7 +222,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -190,7 +222,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
const
auto
K0
=
b_grid_desc_k0_n_k1
.
GetLength
(
I0
);
const
auto
K0
=
b_grid_desc_k0_n_k1
.
GetLength
(
I0
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
N1
=
Number
<
NPerBlock
>
{};
// 128
const
auto
N0
=
N
/
N1
;
const
auto
N0
=
N
/
N1
;
const
auto
b_grid_desc_k0_n0_n1_k1
=
const
auto
b_grid_desc_k0_n0_n1_k1
=
...
@@ -198,8 +230,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -198,8 +230,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
make_tuple
(
make_pass_through_transform
(
K0
),
make_tuple
(
make_pass_through_transform
(
K0
),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
)),
make_pass_through_transform
(
K1
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
// K0, N, K1
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
// K0, N0, N1, K1
return
b_grid_desc_k0_n0_n1_k1
;
return
b_grid_desc_k0_n0_n1_k1
;
}
}
...
@@ -210,33 +242,33 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -210,33 +242,33 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
// 128
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
// 128
const
auto
M0
=
M
/
M1
;
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
N0
=
N
/
N1
;
constexpr
auto
M11
=
constexpr
auto
M11
=
// 64
Number
<
container_reduce
(
M11N11ThreadClusterM110Xs
{},
math
::
multiplies
{},
I1
)
*
Number
<
container_reduce
(
M11N11ThreadClusterM110Xs
{},
math
::
multiplies
{},
I1
)
*
// S<8, 2> ==> 8*2=16
M1PerThreadM111
>
{};
M1PerThreadM111
>
{};
// M1PerThread 4
constexpr
auto
N11
=
constexpr
auto
N11
=
// 64
Number
<
container_reduce
(
M11N11ThreadClusterN110Xs
{},
math
::
multiplies
{},
I1
)
*
Number
<
container_reduce
(
M11N11ThreadClusterN110Xs
{},
math
::
multiplies
{},
I1
)
*
// 16
N1PerThreadN111
>
{};
N1PerThreadN111
>
{};
// N1PerThread 4
constexpr
auto
M10
=
M1
/
M11
;
constexpr
auto
M10
=
M1
/
M11
;
// 2
constexpr
auto
N10
=
N1
/
N11
;
constexpr
auto
N10
=
N1
/
N11
;
// 2
const
auto
c_grid_desc_m0_m10_m11_n0_n10_n11
=
transform_tensor_descriptor
(
const
auto
c_grid_desc_m0_m10_m11_n0_n10_n11
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M10
,
M11
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M10
,
M11
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N10
,
N11
))),
make_unmerge_transform
(
make_tuple
(
N0
,
N10
,
N11
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
// M, N
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
// M0, M10, M11, N0, N10, N11
return
c_grid_desc_m0_m10_m11_n0_n10_n11
;
return
c_grid_desc_m0_m10_m11_n0_n10_n11
;
}
}
// return block_id to C matrix tile idx (m0, n0) mapping
// return block_id to C matrix tile idx (m0, n0) mapping
// what a fuck ???????????? 到底生成了啥
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
{
...
@@ -252,10 +284,11 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -252,10 +284,11 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
static
void
__device__
static
void
Run
(
const
FloatA
B
*
__restrict__
p_a_grid
,
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
const
Float
A
B
*
__restrict__
p_b_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatAB
*
__restrict__
p_shared_block
,
FloatA
*
__restrict__
p_shared_block_a
,
FloatB
*
__restrict__
p_shared_block_b
,
const
AGridDesc_K0_M0_M1_K1
&
a_grid_desc_k0_m0_m1_k1
,
const
AGridDesc_K0_M0_M1_K1
&
a_grid_desc_k0_m0_m1_k1
,
const
BGridDesc_K0_N0_N1_K1
&
b_grid_desc_k0_n0_n1_k1
,
const
BGridDesc_K0_N0_N1_K1
&
b_grid_desc_k0_n0_n1_k1
,
const
CGridDesc_M0_M10_M11_N0_N10_N11
&
c_grid_desc_m0_m10_m11_n0_n10_n11
,
const
CGridDesc_M0_M10_M11_N0_N10_N11
&
c_grid_desc_m0_m10_m11_n0_n10_n11
,
...
@@ -304,12 +337,12 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -304,12 +337,12 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
// TODO: check alignment
// TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM
// A matrix in LDS memory, for blockwise GEMM
constexpr
auto
a_k0_m_k1_block_desc
=
make_naive_tensor_descriptor_aligned
(
constexpr
auto
a_k0_m_k1_block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
// (16, 128, 2), 2
// TODO: check alignment
// TODO: check alignment
// B matrix in LDS memory, for blockwise GEMM
// B matrix in LDS memory, for blockwise GEMM
constexpr
auto
b_k0_n_k1_block_desc
=
make_naive_tensor_descriptor_aligned
(
constexpr
auto
b_k0_n_k1_block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
// (16, 128, 2), 2
static_assert
(
a_block_desc_k0_m0_m1_k1
.
GetElementSpaceSize
()
==
static_assert
(
a_block_desc_k0_m0_m1_k1
.
GetElementSpaceSize
()
==
a_k0_m_k1_block_desc
.
GetElementSpaceSize
()
&&
a_k0_m_k1_block_desc
.
GetElementSpaceSize
()
&&
...
@@ -325,8 +358,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -325,8 +358,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
FloatA
B
,
FloatA
,
FloatA
B
,
FloatA
,
remove_reference_t
<
decltype
(
a_grid_desc_k0_m0_m1_k1
)
>
,
remove_reference_t
<
decltype
(
a_grid_desc_k0_m0_m1_k1
)
>
,
decltype
(
a_block_desc_k0_m0_m1_k1
),
decltype
(
a_block_desc_k0_m0_m1_k1
),
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
...
@@ -349,8 +382,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -349,8 +382,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
BBlockTransferThreadSliceLengths_K0_N0_N1_K1
,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
Float
A
B
,
FloatB
,
Float
A
B
,
FloatB
,
remove_reference_t
<
decltype
(
b_grid_desc_k0_n0_n1_k1
)
>
,
remove_reference_t
<
decltype
(
b_grid_desc_k0_n0_n1_k1
)
>
,
decltype
(
b_block_desc_k0_n0_n1_k1
),
decltype
(
b_block_desc_k0_n0_n1_k1
),
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
...
@@ -368,14 +401,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -368,14 +401,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
// GEMM definition
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// register
const
auto
blockwise_gemm
=
const
auto
blockwise_gemm
=
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
<
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
<
BlockSize
,
BlockSize
,
FloatA
B
,
FloatA
,
// todo split a/b
Float
A
B
,
FloatB
,
FloatAcc
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
...
@@ -400,8 +433,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -400,8 +433,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
constexpr
auto
b_block_aligned_space_size
=
math
::
integer_least_multiple
(
constexpr
auto
b_block_aligned_space_size
=
math
::
integer_least_multiple
(
b_block_desc_k0_n0_n1_k1
.
GetElementSpaceSize
(),
max_lds_align
);
b_block_desc_k0_n0_n1_k1
.
GetElementSpaceSize
(),
max_lds_align
);
FloatA
B
*
p_a_block_double
=
p_shared_block
;
FloatA
*
p_a_block_double
=
p_shared_block
_a
;
Float
A
B
*
p_b_block_double
=
p_shared_block
+
2
*
a_block_aligned_space_size
;
FloatB
*
p_b_block_double
=
p_shared_block
_b
;
// register allocation for output
// register allocation for output
auto
c_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc
>
(
auto
c_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc
>
(
...
@@ -436,7 +469,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -436,7 +469,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
if
constexpr
(
HasMainKBlockLoop
)
if
constexpr
(
HasMainKBlockLoop
)
{
{
const
auto
K0
=
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I0
);
const
auto
K0
=
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I0
);
// K / K1(=2)
index_t
k_block_data_begin
=
0
;
index_t
k_block_data_begin
=
0
;
...
@@ -487,7 +520,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -487,7 +520,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
b_blockwise_copy
.
RunWrite
(
b_block_desc_k0_n0_n1_k1
,
b_block_even_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc_k0_n0_n1_k1
,
b_block_even_buf
);
k_block_data_begin
+=
2
*
K0PerBlock
;
k_block_data_begin
+=
2
*
K0PerBlock
;
}
while
(
k_block_data_begin
<
K0
-
2
*
K0PerBlock
);
}
while
(
k_block_data_begin
<
K0
-
2
*
K0PerBlock
);
// K0PerBlock = 16
}
}
// LDS double buffer: tail
// LDS double buffer: tail
...
...
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp
View file @
068fb458
...
@@ -151,7 +151,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
...
@@ -151,7 +151,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
dst_origin_idx
+
data_to_origin_disp_idx
+
src_vector_idx
);
dst_origin_idx
+
data_to_origin_disp_idx
+
src_vector_idx
);
dst_buf
(
Number
<
dst_offset
>
{})
=
type_convert
<
DstData
>
(
dst_buf
(
Number
<
dst_offset
>
{})
=
type_convert
<
DstData
>
(
src_vector
.
template
AsType
<
Dst
Data
>()[
Number
<
src_vector_offset
>
{}]);
src_vector
.
template
AsType
<
Src
Data
>()[
Number
<
src_vector_offset
>
{}]);
});
});
});
});
}
}
...
...
3rdparty/composable_kernel/include/ck/utility/data_type.hpp
View file @
068fb458
...
@@ -942,6 +942,10 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
...
@@ -942,6 +942,10 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
// uint8_t vector aliases: 2-element and 4-element vector_type storage types.
using uint8x2_t = typename vector_type<uint8_t, 2>::type;
using uint8x4_t = typename vector_type<uint8_t, 4>::type;
// Convert X to Y
// Convert X to Y
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
...
@@ -951,6 +955,13 @@ __host__ __device__ constexpr Y type_convert(X x)
...
@@ -951,6 +955,13 @@ __host__ __device__ constexpr Y type_convert(X x)
return
static_cast
<
Y
>
(
x
);
return
static_cast
<
Y
>
(
x
);
}
}
// Convert uint8_t to half_t.
// Explicit specialization of type_convert for the (half_t, uint8_t) pair; it
// performs a plain static_cast, exactly like the generic template. Marked
// `inline` to match the other explicit specializations in this header.
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, uint8_t>(uint8_t x)
{
    return static_cast<half_t>(x);
}
// convert bfp16 to fp32
// convert bfp16 to fp32
template
<
>
template
<
>
inline
__host__
__device__
constexpr
float
type_convert
<
float
,
bhalf_t
>
(
bhalf_t
x
)
inline
__host__
__device__
constexpr
float
type_convert
<
float
,
bhalf_t
>
(
bhalf_t
x
)
...
...
3rdparty/composable_kernel/include/ck/utility/inner_product.hpp
View file @
068fb458
...
@@ -9,6 +9,31 @@ namespace ck {
...
@@ -9,6 +9,31 @@ namespace ck {
template
<
typename
TA
,
typename
TB
,
typename
TC
>
template
<
typename
TA
,
typename
TB
,
typename
TC
>
__device__
void
inner_product
(
const
TA
&
a
,
const
TB
&
b
,
TC
&
c
);
__device__
void
inner_product
(
const
TA
&
a
,
const
TB
&
b
,
TC
&
c
);
// inner_product specialization for a = half2_t, b = uint8x2_t, c = float.
//
// Steps, as written:
//   1. v_perm_b32 combines `start_byte_for_fp16` with the first 32-bit word of
//      b's storage under `mask_for_elt_01` and writes the result into the
//      first 32-bit word of `b_fp16_vector`.
//      (Presumably this places each uint8 under a 0x64 high byte so the pair
//      reads as two fp16 values of the form 1024 + x -- TODO confirm against
//      the ISA description of v_perm_b32.)
//   2. v_sub_f16x2 subtracts `I8s_TO_F16s_MAGIC_NUM` from that word in place.
//   3. For i = 0, 1: c += int32(a[i]) * int32(b_fp16[i]).
//
// @param a  two half_t values
// @param b  two uint8_t values
// @param c  float accumulator, updated in place
//
// NOTE(review): `b_vector` is a vector_type<uint8_t, 2>, yet its storage is
//   read as a uint32_t and the mask's third selector byte is 0x02. It looks
//   like this reads past the two payload bytes -- confirm the intended layout.
// NOTE(review): 0x6480 only populates the low 16 bits of the 32-bit operand,
//   so the upper fp16 lane would have 0 subtracted -- confirm this is intended
//   (0x64806480 would apply the same offset to both lanes).
// NOTE(review): type_convert<int32_t> is a static_cast, so any fractional
//   part of a[i] and b_fp16[i] is discarded before the multiply -- confirm.
template <>
__device__ void
inner_product<half2_t, uint8x2_t, float>(const half2_t& a, const uint8x2_t& b, float& c)
{
    const vector_type<half_t, 2> a_vector{a};
    const vector_type<uint8_t, 2> b_vector{b};

    // Must NOT be const: it is the output operand of the inline asm below.
    // Writing to a const object through a cast is undefined behaviour and lets
    // the compiler assume the default-initialised value never changes.
    vector_type<half_t, 2> b_fp16_vector;

    static constexpr uint32_t mask_for_elt_01     = 0x05020500;
    static constexpr uint32_t mask_for_elt_23     = 0x05030501; // only used by the disabled asm below
    static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

    asm volatile("v_perm_b32 %0,%1,%2,%3;\n"
                 : "=v"(reinterpret_cast<uint32_t*>(&b_fp16_vector.data_)[0])
                 : "v"(start_byte_for_fp16),
                   "v"(reinterpret_cast<const uint32_t*>(&b_vector.data_)[0]),
                   "v"(mask_for_elt_01));
    // asm volatile("v_perm_b32 %0,%1,%2,%3;\n" : "=v"(((uint32_t*)&b_fp16_vector.data_)[1]) : "v"(start_byte_for_fp16), "v"(((uint32_t*)&b_vector.data_)[0]), "v"(mask_for_elt_23));

    static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x6480;

    asm volatile("v_sub_f16x2 %0, %1, %2;\n"
                 : "=v"(reinterpret_cast<uint32_t*>(&b_fp16_vector.data_)[0])
                 : "v"(reinterpret_cast<uint32_t*>(&b_fp16_vector.data_)[0]),
                   "v"(I8s_TO_F16s_MAGIC_NUM));
    // asm volatile("v_sub_f16x2 %0, %1, %2;\n" : "=v"(((uint32_t*)&b_fp16_vector.data_)[1]) : "v"(((uint32_t*)&b_fp16_vector.data_)[1]), "v"(I8s_TO_F16s_MAGIC_NUM));

    // Accumulate the two element-wise products into c.
    static_for<0, 2, 1>{}([&](auto i) {
        c += type_convert<int32_t>(a_vector.AsType<half_t>()[i]) *
             type_convert<int32_t>(b_fp16_vector.AsType<half_t>()[i]);
    });
}
template
<
>
template
<
>
__device__
void
inner_product
<
float
,
float
,
float
>
(
const
float
&
a
,
const
float
&
b
,
float
&
c
)
__device__
void
inner_product
<
float
,
float
,
float
>
(
const
float
&
a
,
const
float
&
b
,
float
&
c
)
{
{
...
@@ -71,12 +96,6 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
...
@@ -71,12 +96,6 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
c
);
c
);
}
}
// Scalar fp16 inner product accumulating into float.
// The product a * b is formed on the half_t operands first, then cast to
// float and added to the accumulator c.
template <>
__device__ void inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
{
    const float product = static_cast<float>(a * b);
    c = c + product;
}
template
<
>
template
<
>
__device__
void
inner_product
<
half2_t
,
half2_t
,
float
>
(
const
half2_t
&
a
,
const
half2_t
&
b
,
float
&
c
)
__device__
void
inner_product
<
half2_t
,
half2_t
,
float
>
(
const
half2_t
&
a
,
const
half2_t
&
b
,
float
&
c
)
{
{
...
...
3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp
View file @
068fb458
...
@@ -15,6 +15,7 @@ namespace device {
...
@@ -15,6 +15,7 @@ namespace device {
namespace
instance
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
I8
=
int8_t
;
using
F32
=
float
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
...
@@ -34,13 +35,13 @@ using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple<
...
@@ -34,13 +35,13 @@ using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple<
// #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmDl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
16
,
2
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
2
,
1
,
4
,
2
>
,
S
<
8
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
2
>
,
S
<
2
,
1
,
4
,
2
>
,
S
<
8
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
DeviceGemmDl
<
F16
,
I8
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
16
,
2
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
2
,
1
,
4
,
2
>
,
S
<
8
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
2
>
,
S
<
2
,
1
,
4
,
2
>
,
S
<
8
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
// clang-format on
// clang-format on
>
;
>
;
void
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances
(
void
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemm
<
Col
,
Row
,
Row
,
F16
,
I8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
instances
)
{
{
add_device_operation_instances
(
instances
,
device_gemm_dl_f16_f16_f16_km_kn_mn_instances
{});
add_device_operation_instances
(
instances
,
device_gemm_dl_f16_f16_f16_km_kn_mn_instances
{});
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment