ModelZoo / LLama_fastertransformer · Commits

Commit 068fb458
Authored Aug 25, 2023 by liuhy
Modify the CK (composable_kernel) code to support gfx926 (original message: 修改ck代码适配gfx926)
Parent: acd8b8ea
Changes: 27
Showing 20 changed files with 1082 additions and 1018 deletions (+1082 −1018)
Changed files shown below:

- 3rdparty/composable_kernel/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp (+175 −175)
- 3rdparty/composable_kernel/client_example/README.md (+4 −1)
- 3rdparty/composable_kernel/cmake/googletest.cmake (+0 −1)
- 3rdparty/composable_kernel/compile.sh (+1 −1)
- 3rdparty/composable_kernel/example/36_sparse_embedding/CMakeLists.txt (+1 −1)
- 3rdparty/composable_kernel/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp (+195 −195)
- 3rdparty/composable_kernel/include/ck/ck.hpp (+2 −2)
- 3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp (+4 −4)
- 3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp (+68 −68)
- 3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp (+2 −2)
- 3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp (+2 −2)
- 3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp (+1 −1)
- 3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp (+6 −1)
- 3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp (+500 −500)
- 3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp (+1 −8)
- 3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp (+80 −47)
- 3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp (+1 −1)
- 3rdparty/composable_kernel/include/ck/utility/data_type.hpp (+11 −0)
- 3rdparty/composable_kernel/include/ck/utility/inner_product.hpp (+25 −6)
- 3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp (+3 −2)
3rdparty/composable_kernel/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp (file at 068fb458):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iomanip>
#include <vector>
#include <iostream>

#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp"

using ADataType     = ck::half_t; // Input 1
using BDataType     = ck::half_t; // Input 2
using XDataType     = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType  = ck::half_t;
using YDataType     = ck::half_t;
using AccDataType   = float;

using XElementwiseOperation = ck::tensor_operation::element_wise::Add;
using YElementwiseOperation = ck::tensor_operation::element_wise::PassThrough;

constexpr int Rank         = 2;
constexpr int NumReduceDim = 1;

struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};

int main()
{
    bool time_kernel = true;

    ck::index_t M      = 48 * 256;
    ck::index_t N      = 1024;
    ck::index_t Stride = N;

    auto mn_size = (M - 1) * Stride + N;

    SimpleDeviceMem a_dev_buf(sizeof(ADataType) * mn_size);
    SimpleDeviceMem b_dev_buf(sizeof(BDataType) * mn_size);
    SimpleDeviceMem gamma_dev_buf(sizeof(GammaDataType) * N);
    SimpleDeviceMem beta_dev_buf(sizeof(BetaDataType) * N);
    SimpleDeviceMem y_dev_buf(sizeof(YDataType) * mn_size);

    std::array<const void*, 2> ab_input = {a_dev_buf.GetDeviceBuffer(),
                                           b_dev_buf.GetDeviceBuffer()};

    std::vector<ck::index_t> abStride                 = {Stride, 1};
    std::array<std::vector<ck::index_t>, 2> abStrides = {abStride, abStride};

    using DeviceOp = ck::tensor_operation::device::DeviceElementwiseNormalization<
        ck::Tuple<ADataType, BDataType>,
        GammaDataType,
        BetaDataType,
        AccDataType,
        YDataType,
        XElementwiseOperation,
        YElementwiseOperation,
        Rank,
        NumReduceDim>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    std::string best_op_name;
    bool found            = false;
    int best_op_id        = -1;
    float best_ave_time   = std::numeric_limits<float>::max();
    float best_gb_per_sec = 0;

    // profile device operation instances
    std::cout << "Run all instances and do timing" << std::endl;

    for(int i = 0; i < op_ptrs.size(); ++i)
    {
        auto& op_ptr = op_ptrs[i];

        auto argument_ptr = op_ptr->MakeArgumentPointer({M, N},      // lengths
                                                        abStrides,
                                                        {0, 1},      // gammaStrides
                                                        {0, 1},      // betaStrides
                                                        {Stride, 1}, // yStrides
                                                        {1},         // reduceDims
                                                        1e-4,
                                                        ab_input,
                                                        gamma_dev_buf.GetDeviceBuffer(),
                                                        beta_dev_buf.GetDeviceBuffer(),
                                                        y_dev_buf.GetDeviceBuffer(),
                                                        XElementwiseOperation{},
                                                        YElementwiseOperation{});

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});

            std::size_t num_byte = sizeof(ADataType) * M * N + sizeof(BDataType) * M * N +
                                   sizeof(GammaDataType) * N + sizeof(BetaDataType) * N +
                                   sizeof(YDataType) * M * N;

            float gb_per_sec = num_byte / 1.E6 / ave_time;

            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec
                      << " GB/s, " << op_name << std::endl;

            if(ave_time < best_ave_time)
            {
                found           = true;
                best_op_id      = i;
                best_op_name    = op_name;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }
        }
        else
        {
            std::cout << op_name << " does not support this problem" << std::endl;
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
              << best_op_name << std::endl;

    // run the best intance
    {
        auto& op_ptr = op_ptrs[best_op_id];
        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;

        auto argument_ptr = op_ptr->MakeArgumentPointer({M, N},      // lengths
                                                        abStrides,
                                                        {1},         // gammaStrides
                                                        {1},         // betaStrides
                                                        {Stride, 1}, // yStrides
                                                        {1},         // reduceDims
                                                        1e-4,
                                                        ab_input,
                                                        gamma_dev_buf.GetDeviceBuffer(),
                                                        beta_dev_buf.GetDeviceBuffer(),
                                                        y_dev_buf.GetDeviceBuffer(),
                                                        XElementwiseOperation{},
                                                        YElementwiseOperation{});

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        }

        std::cout << "Done" << std::endl;
    }
    return 0;
}
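The SimpleDeviceMem helper above deliberately discards the hipMalloc/hipFree return codes with (void) casts. A minimal hardened variant is sketched below; it is illustrative only (the checked_hip_malloc name is not part of CK or of this commit) and uses only standard HIP runtime calls.

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>
#include <cstdlib>

// Sketch: allocate device memory and abort with a readable message on failure,
// instead of silently ignoring the hipError_t as the example above does.
inline void* checked_hip_malloc(std::size_t mem_size)
{
    void* p            = nullptr;
    const hipError_t e = hipMalloc(&p, mem_size);
    if(e != hipSuccess)
    {
        std::fprintf(stderr, "hipMalloc(%zu bytes) failed: %s\n", mem_size, hipGetErrorString(e));
        std::abort();
    }
    return p;
}
```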
3rdparty/composable_kernel/client_example/README.md (diff at 068fb458):

@@ -9,7 +9,10 @@ cd client_example/build
 ```bash
-cmake -D CMAKE_CXX_COMPILER=${ROCM_PATH}/bin/hipcc -D CMAKE_PREFIX_PATH="${ROCM_PATH};${PATH_TO_CK_INSTALL_DIRECTORY}" ..
+cmake \
+-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
+-D CMAKE_PREFIX_PATH="/opt/rocm;${PATH_TO_CK_INSTALL_DIRECTORY}" \
+..
 ```

 ### Build client example
3rdparty/composable_kernel/cmake/googletest.cmake (diff at 068fb458):

@@ -27,7 +27,6 @@ message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLA
 FetchContent_Declare(
   googletest
   GIT_REPOSITORY http://10.0.50.24/Mirrors/googletest.git
-  # GIT_REPOSITORY /work/home/zhangshao/installer/googletest
   GIT_TAG b85864c64758dec007208e56af933fc3f52044ee
 )
3rdparty/composable_kernel/compile.sh (diff at 068fb458):

@@ -7,7 +7,7 @@ cmake
 -D CMAKE_CXX_FLAGS="-O3" \
 -D CMAKE_BUILD_TYPE=Release \
 -D GPU_TARGETS="gfx906;gfx926" \
--D CMAKE_INSTALL_PREFIX=~/composable_kernel/install_ck \
+-D CMAKE_INSTALL_PREFIX=~/composable_kernel-develop/install_ck \
 ..
 cd -
3rdparty/composable_kernel/example/36_sparse_embedding/CMakeLists.txt (file at 068fb458):

add_example_executable(example_sparse_embedding3_forward_layernorm sparse_embedding3_forward_layernorm.cpp)
3rdparty/composable_kernel/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp (file at 068fb458):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>

#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"

using ADataType     = ck::half_t; // Input 1
using BDataType     = ck::half_t; // Input 2
using XDataType     = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType  = ck::half_t;
using YDataType     = ck::half_t;
using AccDataType   = float;

using XElementwiseOperation = ck::tensor_operation::element_wise::Add;
using YElementwiseOperation = ck::tensor_operation::element_wise::PassThrough;

constexpr int Rank         = 2;
constexpr int NumReduceDim = 1;

// X = Elementwise(input1, input2, input3, ...)
// Y = Layernorm(X, beta, gamma)
using DeviceInstance = ck::tensor_operation::device::DeviceElementwiseNormalizationImpl<
    ck::Tuple<ADataType, BDataType>,
    GammaDataType,
    BetaDataType,
    AccDataType,
    YDataType,
    XElementwiseOperation,
    YElementwiseOperation,
    Rank,
    NumReduceDim,
    256, // BlockSize
    8,   // ClusterM
    32,  // ClusterK
    1,   // SliceM
    32,  // SliceK
    1,   // SrcVecDim (0=M, 1=K)
    8,   // SrcScalarPerVector
    1,   // GammaVecDim (0=M, 1=K)
    8,   // GammaScalarPerVector
    1,   // BetaVecDim (0=M, 1=K)
    8,   // BetaScalarPerVector
    8>;  // OutScalarPerVector

template <typename HostTensorA, typename HostTensorB, typename HostTensorC, typename Functor>
void host_elementwise2D(HostTensorC& C,
                        const HostTensorA& A,
                        const HostTensorB& B,
                        const std::vector<std::size_t>& shape,
                        Functor functor)
{
    using ctype = ck::remove_reference_t<decltype(C(0, 0))>;

    for(std::size_t m = 0; m < shape[0]; ++m)
        for(std::size_t n = 0; n < shape[1]; ++n)
        {
            auto a_val  = A(m, n);
            auto b_val  = B(m, n);
            ctype c_val = 0;
            functor(c_val, a_val, b_val);
            C(m, n) = c_val;
        }
}

int main()
{
    bool time_kernel = true;

    ck::index_t M      = 48 * 256;
    ck::index_t N      = 1024;
    ck::index_t Stride = N;

    auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
        return HostTensorDescriptor(std::vector<std::size_t>({len}),
                                    std::vector<std::size_t>({stride}));
    };

    auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
        return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                    std::vector<std::size_t>({stride, 1}));
    };

    Tensor<ADataType> a(f_host_tensor_descriptor2d(M, N, Stride));
    Tensor<BDataType> b(f_host_tensor_descriptor2d(M, N, Stride));
    Tensor<GammaDataType> gamma(f_host_tensor_descriptor1d(N, 1));
    Tensor<BetaDataType> beta(f_host_tensor_descriptor1d(N, 1));
    Tensor<YDataType> y(f_host_tensor_descriptor2d(M, N, Stride));

    a.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
    b.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
    gamma.GenerateTensorValue(GeneratorTensor_2<GammaDataType>{-5, 5});
    beta.GenerateTensorValue(GeneratorTensor_2<BetaDataType>{-5, 5});

    DeviceMem a_dev(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
    DeviceMem b_dev(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
    DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
    DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
    DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());

    a_dev.ToDevice(a.mData.data());
    b_dev.ToDevice(b.mData.data());
    gamma_dev.ToDevice(gamma.mData.data());
    beta_dev.ToDevice(beta.mData.data());

    std::array<const void*, 2> input = {a_dev.GetDeviceBuffer(), b_dev.GetDeviceBuffer()};

    auto device_instance = DeviceInstance{};

    auto argument_ptr = device_instance.MakeArgumentPointer(
        {M, N},
        {
            std::vector<ck::index_t>{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()},
            std::vector<ck::index_t>{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()},
        },
        {0, 1},
        {0, 1},
        std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
        {1},
        1e-4,
        input,
        gamma_dev.GetDeviceBuffer(),
        beta_dev.GetDeviceBuffer(),
        y_dev.GetDeviceBuffer(),
        XElementwiseOperation{},
        YElementwiseOperation{});

    if(!device_instance.IsSupportedArgument(argument_ptr.get()))
    {
        std::cout << "The runtime parameters are not supported" << std::endl;
        return 1;
    };

    auto invoker_ptr = device_instance.MakeInvokerPointer();

    float ela_time = 0;
    ela_time       = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

    float data_mem_size = M * N * sizeof(ADataType) + M * N * sizeof(BDataType) +
                          M * N * sizeof(YDataType) + N * sizeof(GammaDataType) +
                          N * sizeof(BetaDataType);
    float bandwidth = data_mem_size * 1000 / ela_time / 1024 / 1024 / 1024;

    std::cout << "Bandwidth is : " << bandwidth << "GB/s . " << std::endl;
    std::cout << "Time elapase is : " << ela_time << " ms . " << std::endl;

    bool pass = true;
    {
        std::vector<std::size_t> mn = {static_cast<unsigned long>(M),
                                       static_cast<unsigned long>(N)};
        Tensor<XDataType> x(f_host_tensor_descriptor2d(M, N, Stride));
        host_elementwise2D<Tensor<ADataType>,
                           Tensor<BDataType>,
                           Tensor<XDataType>,
                           XElementwiseOperation>(x, a, b, mn, XElementwiseOperation{});

        Tensor<YDataType> host_y(f_host_tensor_descriptor2d(M, N, Stride));
        using ReferenceInstance =
            ck::tensor_operation::host::ReferenceLayernorm<XDataType,
                                                           GammaDataType,
                                                           BetaDataType,
                                                           YDataType,
                                                           AccDataType,
                                                           YElementwiseOperation,
                                                           Rank,
                                                           NumReduceDim>;

        ReferenceInstance ref;
        auto ref_argument =
            ref.MakeArgument(x, gamma, beta, host_y, YElementwiseOperation{}, {M, N}, {1}, 1e-4);
        auto ref_invoker = ref.MakeInvoker();
        ref_invoker.Run(ref_argument);

        y_dev.FromDevice(y.mData.data());
        pass &=
            ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3);

        if(!(pass))
        {
            std::cout << "layernorm wrong" << std::endl;
        }
    }

    return (pass ? 0 : 1);
}
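As a sanity check on the bandwidth formula in the example above (data_mem_size * 1000 / ela_time / 1024 / 1024 / 1024, i.e. GiB/s even though the message prints "GB/s"), here is the arithmetic for the shapes the example uses; the 0.1 ms timing is an assumed placeholder, not a measured result.

```latex
% M = 48 \cdot 256 = 12288,\ N = 1024, all tensors fp16 (2 bytes/element)
\begin{aligned}
\text{data\_mem\_size} &= 3MN \cdot 2 + 2N \cdot 2
  = 75\,497\,472 + 4\,096 = 75\,501\,568\ \text{bytes} \approx 72\ \text{MiB},\\
\text{bandwidth} &= \frac{\text{data\_mem\_size}\cdot 1000}{\text{ela\_time(ms)}\cdot 1024^{3}}
  \quad\Rightarrow\quad \text{ela\_time}=0.1\ \text{ms} \;\mapsto\; \approx 703\ \text{GiB/s}.
\end{aligned}
```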
3rdparty/composable_kernel/include/ck/ck.hpp (diff at 068fb458):

@@ -33,7 +33,7 @@
 // buffer resource
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_BUFFER_RESOURCE_3RD_DWORD -1
-#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)||defined(__gfx926__) || defined(__gfx908__) || \
+#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx926__) || defined(__gfx908__) || \
     defined(__gfx90a__) // for GPU code
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
 #elif defined(__gfx1030__) // for GPU code
@@ -46,7 +46,7 @@
 #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
 #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
 #define CK_USE_AMD_V_MAC_F32
-#elif defined(__gfx906__)|| defined(__gfx926__) || defined(__gfx908__) || defined(__gfx90a__) || \
+#elif defined(__gfx906__) || defined(__gfx926__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx1030__) // for GPU code
 #define CK_USE_AMD_V_FMAC_F32
 #define CK_USE_AMD_V_DOT2_F32_F16
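A minimal compile-time probe, assuming the patched ck.hpp above is on the include path (this check is not part of the commit): when a translation unit is compiled for the gfx926 device pass, the buffer-resource constant should resolve to the GPU value rather than the host default of -1.

```cpp
#include "ck/ck.hpp"

// Illustrative only: fails the build if the gfx926 device pass did not take the
// same CK_BUFFER_RESOURCE_3RD_DWORD branch as gfx906/gfx908/gfx90a.
#if defined(__HIP_DEVICE_COMPILE__) && defined(__gfx926__)
static_assert(CK_BUFFER_RESOURCE_3RD_DWORD == 0x00020000,
              "gfx926 should select the 0x00020000 buffer-resource dword");
#endif
```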
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp (diff at 068fb458):

@@ -225,13 +225,13 @@ struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_
         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
             a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
             b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());

         constexpr auto threadwise_contraction =
             ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
                 FloatA,
-                FloatB,
+                FloatA,
                 FloatC,
                 decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
                 decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
@@ -394,8 +394,8 @@ struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_
                 Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder

         using BThreadCopy =
             ThreadwiseTensorSliceTransfer_v4r1<
-                FloatB,
-                FloatB,
+                FloatB, // src
+                FloatA, // dst
                 decltype(b_block_desc_bk0_bn0_bn1_bk1_),
                 decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
                 Sequence<BK0PerThread, 1, BN1PerThreadBN11, BK1>, // SliceLengths
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp (file at 068fb458):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename InDataTypeTuple,
          typename GammaDataType,
          typename BetaDataType,
          typename AccDataType,
          typename YDataType,
          typename XElementwiseOperation,
          typename YElementwiseOperation,
          index_t Rank,
          index_t NumReduceDim>
struct DeviceElementwiseNormalization : public BaseOperator
{
    static constexpr int NumInput = InDataTypeTuple::Size();

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::vector<index_t> lengths,
                        const std::array<std::vector<index_t>, NumInput> inStridesArray,
                        const std::vector<index_t> gammaStrides,
                        const std::vector<index_t> betaStrides,
                        const std::vector<index_t> yStrides,
                        const std::vector<index_t> reduceDims,
                        AccDataType epsilon,
                        const std::array<const void*, NumInput> in_dev_buffers,
                        const void* p_gamma,
                        const void* p_beta,
                        void* p_y,
                        XElementwiseOperation x_elementwise_op,
                        YElementwiseOperation y_elementwise_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename InDataTypeTuple,
          typename GammaDataType,
          typename BetaDataType,
          typename AccDataType,
          typename YDataType,
          typename XElementwiseOperation,
          typename YElementwiseOperation,
          index_t Rank,
          index_t NumReduceDim>
using DeviceElementwiseNormalizationPtr =
    std::unique_ptr<DeviceElementwiseNormalization<InDataTypeTuple,
                                                   GammaDataType,
                                                   BetaDataType,
                                                   AccDataType,
                                                   YDataType,
                                                   XElementwiseOperation,
                                                   YElementwiseOperation,
                                                   Rank,
                                                   NumReduceDim>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
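For reference, the operation behind this interface — per the comments in the examples above, X = Elementwise(input1, input2, ...) followed by Y = Layernorm(X, beta, gamma) — is the usual layer normalization over the reduced dimension (the last of the two dimensions in the examples, with epsilon = 1e-4):

```latex
x_{m,n} = a_{m,n} + b_{m,n}, \qquad
\mu_m = \frac{1}{N}\sum_{n=1}^{N} x_{m,n}, \qquad
\sigma_m^2 = \frac{1}{N}\sum_{n=1}^{N}\left(x_{m,n}-\mu_m\right)^2,
\qquad
y_{m,n} = \gamma_n\,\frac{x_{m,n}-\mu_m}{\sqrt{\sigma_m^2+\varepsilon}} + \beta_n .
```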
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp (diff at 068fb458):

@@ -134,7 +134,7 @@ __global__ void
                         const Block2CTileMap block_2_ctile_map,
                         const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__)||defined(__gfx926__) || defined(__gfx1030__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx926__) || defined(__gfx1030__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -709,7 +709,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
         namespace ctc = tensor_layout::convolution;

         // check device
-        if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030"))
+        if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030"))
         {
             return false;
         }
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp (diff at 068fb458):

@@ -106,7 +106,7 @@ __global__ void
                         const Block2CTileMap block_2_ctile_map,
                         const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__)|| defined(__gfx926__) || defined(__gfx1030__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx926__) || defined(__gfx1030__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -600,7 +600,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
         namespace ctc = tensor_layout::convolution;

         // check device
-        if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030"))
+        if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030"))
         {
             return false;
         }
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp (diff at 068fb458):

@@ -1391,7 +1391,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
     static bool IsSupportedArgument(const Argument& arg)
     {
         // check device
         if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx926" ||
              ck::get_device_name() == "gfx1030"))
         {
             return false;
         }
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp (diff at 068fb458):

@@ -205,6 +205,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
     using GridwiseGemm =
         GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
                                      ADataType,
+                                     BDataType,
                                      AccDataType,
                                      CDataType,
                                      InMemoryDataOperationEnum::Set,
@@ -364,6 +365,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
                 const auto kernel =
                     kernel_gemm_dl_v1r3<GridwiseGemm,
                                         ADataType,
+                                        BDataType,
                                         CDataType,
                                         remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                                         remove_reference_t<BGridDesc_K0_N0_N1_K1>,
@@ -390,6 +392,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
                 const auto kernel =
                     kernel_gemm_dl_v1r3<GridwiseGemm,
                                         ADataType,
+                                        BDataType,
                                         CDataType,
                                         remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                                         remove_reference_t<BGridDesc_K0_N0_N1_K1>,
@@ -416,6 +419,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
                 const auto kernel =
                     kernel_gemm_dl_v1r3<GridwiseGemm,
                                         ADataType,
+                                        BDataType,
                                         CDataType,
                                         remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                                         remove_reference_t<BGridDesc_K0_N0_N1_K1>,
@@ -442,6 +446,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
                 const auto kernel =
                     kernel_gemm_dl_v1r3<GridwiseGemm,
                                         ADataType,
+                                        BDataType,
                                         CDataType,
                                         remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                                         remove_reference_t<BGridDesc_K0_N0_N1_K1>,
@@ -483,7 +488,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx926" ||
            ck::get_device_name() == "gfx1030")
         {
             return GridwiseGemm::CheckValidity(
                 arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
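The IsSupportedArgument checks above compare ck::get_device_name() against the strings "gfx906", "gfx926" and "gfx1030". A quick host-side way to see what architecture string the HIP runtime reports for device 0 is sketched below (standalone, not CK code; the reported string may carry feature suffixes in addition to the gfx name).

```cpp
#include <hip/hip_runtime.h>
#include <iostream>

int main()
{
    hipDeviceProp_t prop;
    if(hipGetDeviceProperties(&prop, 0) != hipSuccess)
    {
        std::cerr << "hipGetDeviceProperties failed" << std::endl;
        return 1;
    }
    // gcnArchName holds the raw architecture string (e.g. a gfx926 board would
    // report a name starting with "gfx926"); compare it with what
    // ck::get_device_name() returns on the same machine.
    std::cout << "device 0 gcnArchName: " << prop.gcnArchName << std::endl;
    return 0;
}
```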
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp: diff collapsed in this capture (+500 −500, content not shown).
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp (diff at 068fb458):

@@ -117,18 +117,11 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
-        bool ret = (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
+        return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
                 K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
                 K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
                 K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
                (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
-        if(!ret){
-            std::cout << "M=" << M << " N=" << N << " K0=" << K0
-                      << " c_grid_desc_m_n[0]=" << c_grid_desc_m_n.GetLength(I0)
-                      << " c_grid_desc_m_n[1]=" << c_grid_desc_m_n.GetLength(I1)
-                      << " b_grid_desc_k0_n_k1[0]=" << b_grid_desc_k0_n_k1.GetLength(I0)
-                      << " b_grid_desc_k0_n_k1[2]=" << b_grid_desc_k0_n_k1.GetLength(I2)
-                      << " a_grid_desc_k0_m_k1[2]=" << a_grid_desc_k0_m_k1.GetLength(I2)
-                      << " K1=" << K1 << " MPerBlock=" << MPerBlock << " NPerBlock=" << NPerBlock
-                      << " K0PerBlock=" << K0PerBlock << std::endl;
-        }
-        return ret;
     }

     __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
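The condition enforced by the validity check above can be restated on its own; the sketch below is not CK code and the tile sizes are illustrative placeholders, but it shows the divisibility requirement a problem size must satisfy for this DL GEMM path.

```cpp
// Illustrative restatement of the shape checks above: descriptor lengths must
// agree, and M, N, K0 must be multiples of the per-block tile sizes.
constexpr bool dl_gemm_problem_is_valid(
    long M, long N, long K0, long MPerBlock, long NPerBlock, long K0PerBlock)
{
    return M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0;
}

// Example with assumed tile sizes (128x128 tiles, K0PerBlock = 16):
static_assert(dl_gemm_problem_is_valid(12288, 1024, 64, 128, 128, 16),
              "these example shapes tile evenly");
```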
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
View file @
068fb458
...
@@ -18,7 +18,8 @@
...
@@ -18,7 +18,8 @@
namespace
ck
{
namespace
ck
{
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
FloatC
,
typename
AGridDesc_K0_M0_M1_K1
,
typename
AGridDesc_K0_M0_M1_K1
,
typename
BGridDesc_K0_N0_N1_K1
,
typename
BGridDesc_K0_N0_N1_K1
,
...
@@ -30,23 +31,27 @@ __global__ void
...
@@ -30,23 +31,27 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_gemm_dl_v1r3
(
const
FloatA
B
*
__restrict__
p_a_grid
,
kernel_gemm_dl_v1r3
(
const
FloatA
*
__restrict__
p_a_grid
,
const
Float
A
B
*
__restrict__
p_b_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_K0_M0_M1_K1
a_grid_desc_k0_m0_m1_k1
,
const
AGridDesc_K0_M0_M1_K1
a_grid_desc_k0_m0_m1_k1
,
const
BGridDesc_K0_N0_N1_K1
b_grid_desc_k0_n0_n1_k1
,
const
BGridDesc_K0_N0_N1_K1
b_grid_desc_k0_n0_n1_k1
,
const
CGridDesc_M0_M10_M11_N0_N10_N11
c_grid_desc_m0_m10_m11_n0_n10_n11
,
const
CGridDesc_M0_M10_M11_N0_N10_N11
c_grid_desc_m0_m10_m11_n0_n10_n11
,
const
Block2CTileMap
block_2_ctile_map
)
const
Block2CTileMap
block_2_ctile_map
)
{
{
constexpr
index_t
shared_block_size
=
constexpr
index_t
shared_block_size_of_a
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
GridwiseGemm
::
GetSharedMemoryNumberOfByteToA
()
/
sizeof
(
FloatA
);
constexpr
index_t
shared_block_size_of_b
=
GridwiseGemm
::
GetSharedMemoryNumberOfByteToB
()
/
sizeof
(
FloatB
);
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
__shared__
FloatA
p_shared_block_a
[
shared_block_size_of_a
];
__shared__
FloatB
p_shared_block_b
[
shared_block_size_of_b
];
GridwiseGemm
::
Run
(
p_a_grid
,
GridwiseGemm
::
Run
(
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
p_shared_block
,
p_shared_block_a
,
p_shared_block_b
,
a_grid_desc_k0_m0_m1_k1
,
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
b_grid_desc_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
...
@@ -56,7 +61,8 @@ __global__ void
...
@@ -56,7 +61,8 @@ __global__ void
}
}
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatA
,
typename
FloatB
,
typename
FloatAcc
,
typename
FloatAcc
,
typename
FloatC
,
typename
FloatC
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
...
@@ -99,7 +105,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -99,7 +105,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
// K1 should be Number<...>
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
ToA
()
{
{
// TODO: change this. I think it needs multi-dimensional alignment
// TODO: change this. I think it needs multi-dimensional alignment
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
max_lds_align
=
K1
;
...
@@ -122,7 +128,33 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -122,7 +128,33 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
constexpr
auto
b_block_aligned_space_size
=
constexpr
auto
b_block_aligned_space_size
=
math
::
integer_least_multiple
(
b_block_desc_k_n
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
b_block_desc_k_n
.
GetElementSpaceSize
(),
max_lds_align
);
return
2
*
(
a_block_aligned_space_size
+
b_block_aligned_space_size
)
*
sizeof
(
FloatAB
);
return
2
*
a_block_aligned_space_size
*
sizeof
(
FloatA
);
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByteToB
()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr
auto
max_lds_align
=
K1
;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_k_m
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_k_n
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
// TODO: check alignment
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_aligned_space_size
=
math
::
integer_least_multiple
(
a_block_desc_k_m
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_aligned_space_size
=
math
::
integer_least_multiple
(
b_block_desc_k_n
.
GetElementSpaceSize
(),
max_lds_align
);
return
2
*
b_block_aligned_space_size
*
sizeof
(
FloatB
);
}
}
__host__
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
...
@@ -145,14 +177,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -145,14 +177,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
{
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
// M0 * N0
return
grid_size
;
return
grid_size
;
}
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K0
)
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K0
)
{
{
const
bool
has_main_k_block_loop
=
(
K0
+
K0PerBlock
)
/
(
2
*
K0PerBlock
)
>
1
;
const
bool
has_main_k_block_loop
=
(
K0
+
K0PerBlock
)
/
(
2
*
K0PerBlock
)
>
1
;
// K0 > K0PerBlock ???
return
has_main_k_block_loop
;
return
has_main_k_block_loop
;
}
}
...
@@ -170,7 +202,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -170,7 +202,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
M1
=
Number
<
MPerBlock
>
{};
const
auto
M1
=
Number
<
MPerBlock
>
{};
// 128
const
auto
M0
=
M
/
M1
;
const
auto
M0
=
M
/
M1
;
const
auto
a_grid_desc_k0_m0_m1_k1
=
const
auto
a_grid_desc_k0_m0_m1_k1
=
...
@@ -178,8 +210,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -178,8 +210,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
make_tuple
(
make_pass_through_transform
(
K0
),
make_tuple
(
make_pass_through_transform
(
K0
),
make_unmerge_transform
(
make_tuple
(
M0
,
M1
)),
make_unmerge_transform
(
make_tuple
(
M0
,
M1
)),
make_pass_through_transform
(
K1
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
// K0, M, K1
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
// K0, M0, M1, K1
return
a_grid_desc_k0_m0_m1_k1
;
return
a_grid_desc_k0_m0_m1_k1
;
}
}
...
@@ -190,7 +222,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -190,7 +222,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
const
auto
K0
=
b_grid_desc_k0_n_k1
.
GetLength
(
I0
);
const
auto
K0
=
b_grid_desc_k0_n_k1
.
GetLength
(
I0
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
N1
=
Number
<
NPerBlock
>
{};
// 128
const
auto
N0
=
N
/
N1
;
const
auto
N0
=
N
/
N1
;
const
auto
b_grid_desc_k0_n0_n1_k1
=
const
auto
b_grid_desc_k0_n0_n1_k1
=
...
@@ -198,8 +230,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -198,8 +230,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
make_tuple
(
make_pass_through_transform
(
K0
),
make_tuple
(
make_pass_through_transform
(
K0
),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
)),
make_pass_through_transform
(
K1
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
// K0, N, K1
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
// K0, N0, N1, K1
return
b_grid_desc_k0_n0_n1_k1
;
return
b_grid_desc_k0_n0_n1_k1
;
}
}
...
@@ -210,33 +242,33 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -210,33 +242,33 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
// 128
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
// 128
const
auto
M0
=
M
/
M1
;
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
N0
=
N
/
N1
;
constexpr
auto
M11
=
constexpr
auto
M11
=
// 64
Number
<
container_reduce
(
M11N11ThreadClusterM110Xs
{},
math
::
multiplies
{},
I1
)
*
Number
<
container_reduce
(
M11N11ThreadClusterM110Xs
{},
math
::
multiplies
{},
I1
)
*
// S<8, 2> ==> 8*2=16
M1PerThreadM111
>
{};
M1PerThreadM111
>
{};
// M1PerThread 4
constexpr
auto
N11
=
constexpr
auto
N11
=
// 64
Number
<
container_reduce
(
M11N11ThreadClusterN110Xs
{},
math
::
multiplies
{},
I1
)
*
Number
<
container_reduce
(
M11N11ThreadClusterN110Xs
{},
math
::
multiplies
{},
I1
)
*
// 16
N1PerThreadN111
>
{};
N1PerThreadN111
>
{};
// N1PerThread 4
constexpr
auto
M10
=
M1
/
M11
;
constexpr
auto
M10
=
M1
/
M11
;
// 2
constexpr
auto
N10
=
N1
/
N11
;
constexpr
auto
N10
=
N1
/
N11
;
// 2
const
auto
c_grid_desc_m0_m10_m11_n0_n10_n11
=
transform_tensor_descriptor
(
const
auto
c_grid_desc_m0_m10_m11_n0_n10_n11
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M10
,
M11
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M10
,
M11
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N10
,
N11
))),
make_unmerge_transform
(
make_tuple
(
N0
,
N10
,
N11
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
// M, N
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
// M0, M10, M11, N0, N10, N11
return
c_grid_desc_m0_m10_m11_n0_n10_n11
;
return
c_grid_desc_m0_m10_m11_n0_n10_n11
;
}
}
    // return block_id to C matrix tile idx (m0, n0) mapping
+   // TODO: work out exactly what this block-to-tile mapping generates
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
    {
...
@@ -252,10 +284,11 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void
-   Run(const FloatAB* __restrict__ p_a_grid,
-       const FloatAB* __restrict__ p_b_grid,
+   Run(const FloatA* __restrict__ p_a_grid,
+       const FloatB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
-       FloatAB* __restrict__ p_shared_block,
+       FloatA* __restrict__ p_shared_block_a,
+       FloatB* __restrict__ p_shared_block_b,
        const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
        const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
        const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
...
@@ -304,12 +337,12 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
        // TODO: check alignment
        // A matrix in LDS memory, for blockwise GEMM
        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
-           make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+           make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align); // (16, 128, 2), 2

        // TODO: check alignment
        // B matrix in LDS memory, for blockwise GEMM
        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
-           make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+           make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align); // (16, 128, 2), 2

        static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
                          a_k0_m_k1_block_desc.GetElementSpaceSize() &&
...
@@ -325,8 +358,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
            ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
            ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
            ABlockTransferThreadClusterArrangeOrder,
-           FloatAB,
-           FloatAB,
+           FloatA,
+           FloatA,
            remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
            decltype(a_block_desc_k0_m0_m1_k1),
            ABlockTransferSrcAccessOrder,
...
@@ -349,8 +382,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
            BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
            BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
            BBlockTransferThreadClusterArrangeOrder,
-           FloatAB,
-           FloatAB,
+           FloatB,
+           FloatB,
            remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
            decltype(b_block_desc_k0_n0_n1_k1),
            BBlockTransferSrcAccessOrder,
...
@@ -368,14 +401,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[K0PerBlock, MPerBlock] is in LDS
-       // b_mtx[KPerBlocl, NPerBlock] is in LDS
+       // b_mtx[K0PerBlocl, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        const auto blockwise_gemm =
            BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
                BlockSize,
-               FloatAB,
-               FloatAB,
+               FloatA, // todo split a/b
+               FloatB,
                FloatAcc,
                decltype(a_k0_m_k1_block_desc),
                decltype(b_k0_n_k1_block_desc),
...
@@ -400,8 +433,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);

-       FloatAB* p_a_block_double = p_shared_block;
-       FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
+       FloatA* p_a_block_double = p_shared_block_a;
+       FloatB* p_b_block_double = p_shared_block_b;
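Splitting the single FloatAB pointer into p_shared_block_a / p_shared_block_b lets the A and B tiles sit in LDS with different element types (for example fp16 activations next to int8 weights). The snippet below is only a rough, standalone sizing sketch: the tile shape (K0PerBlock, MPerBlock/NPerBlock, K1) = (16, 128, 2) is assumed from the comments above, and the max_lds_align padding the real kernel applies is ignored:

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Approximate LDS footprint of a double-buffered block tile when A is fp16
// and B is int8 (illustrative assumption; alignment padding ignored).
int main()
{
    constexpr std::size_t K0PerBlock = 16, MPerBlock = 128, NPerBlock = 128, K1 = 2;

    constexpr std::size_t a_elems = K0PerBlock * MPerBlock * K1; // one A tile
    constexpr std::size_t b_elems = K0PerBlock * NPerBlock * K1; // one B tile

    // LDS double buffering keeps two copies of each tile resident.
    constexpr std::size_t lds_bytes =
        2 * a_elems * sizeof(std::uint16_t) + // FloatA = half, 2 bytes/element
        2 * b_elems * sizeof(std::int8_t);    // FloatB = int8, 1 byte/element

    std::printf("assumed LDS footprint: %zu bytes\n", lds_bytes);
    return 0;
}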
        // register allocation for output
        auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
...
@@ -436,7 +469,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
        if constexpr(HasMainKBlockLoop)
        {
-           const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
+           const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0); // K / K1(=2)

            index_t k_block_data_begin = 0;
...
@@ -487,7 +520,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);

                k_block_data_begin += 2 * K0PerBlock;
-           } while(k_block_data_begin < K0 - 2 * K0PerBlock);
+           } while(k_block_data_begin < K0 - 2 * K0PerBlock); // K0PerBlock = 16
        }

        // LDS double buffer: tail
...
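The // K0PerBlock = 16 comment added to the main-loop condition above makes the double-buffer stepping explicit: each pass consumes 2 * K0PerBlock slices of K0, and the final two tiles are left for the tail code. A standalone sketch of the resulting trip count, with K = 4096 chosen purely as an example:

#include <cstdio>

// Trip count of the LDS double-buffer main loop. K0PerBlock = 16 and
// K1 = 2 come from the comments in this commit; K = 4096 is an example.
int main()
{
    constexpr int K1         = 2;
    constexpr int K0PerBlock = 16;
    constexpr int K          = 4096;
    constexpr int K0         = K / K1; // "// K / K1(=2)" in the hunk above

    int iters              = 0;
    int k_block_data_begin = 0;
    do
    {
        k_block_data_begin += 2 * K0PerBlock; // one even + one odd buffer per pass
        ++iters;
    } while(k_block_data_begin < K0 - 2 * K0PerBlock);

    std::printf("main-loop iterations: %d (remaining tiles go to the tail)\n", iters);
    return 0;
}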
3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp
View file @ 068fb458
...
@@ -151,7 +151,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
                    dst_origin_idx + data_to_origin_disp_idx + src_vector_idx);

                dst_buf(Number<dst_offset>{}) = type_convert<DstData>(
-                   src_vector.template AsType<DstData>()[Number<src_vector_offset>{}]);
+                   src_vector.template AsType<SrcData>()[Number<src_vector_offset>{}]);
            });
        });
    }
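The one-word fix matters because src_vector physically holds SrcData elements: AsType<SrcData> reads the buffer back with the type it was written with, and type_convert<DstData> then performs the value conversion, whereas AsType<DstData> would reinterpret the raw bits whenever the two types differ. A host-side sketch of that distinction in plain C++, with int8 -> float standing in for the CK types:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Value conversion vs. bit reinterpretation, illustrated with int8 -> float.
int main()
{
    const std::int8_t src[4] = {1, 2, 3, 4};

    // Correct: read each element as its source type, then convert the value.
    for(int i = 0; i < 4; ++i)
    {
        const float converted = static_cast<float>(src[i]); // 1.0f, 2.0f, ...
        std::printf("convert:     %g\n", converted);
    }

    // Wrong: reinterpret the same bytes as the destination type.
    float reinterpreted;
    std::memcpy(&reinterpreted, src, sizeof(reinterpreted));
    std::printf("reinterpret: %g (bit pattern 0x04030201, not 1..4)\n", reinterpreted);
    return 0;
}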
...
3rdparty/composable_kernel/include/ck/utility/data_type.hpp
View file @ 068fb458
...
@@ -942,6 +942,10 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;

+// u8
+using uint8x2_t = typename vector_type<uint8_t, 2>::type;
+using uint8x4_t = typename vector_type<uint8_t, 4>::type;
+
// Convert X to Y
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert(X x)
...
@@ -951,6 +955,13 @@ __host__ __device__ constexpr Y type_convert(X x)
    return static_cast<Y>(x);
}

+// Convert X to Y
+template <>
+__host__ __device__ constexpr half_t type_convert<half_t, uint8_t>(uint8_t x)
+{
+    return static_cast<half_t>(x);
+}
+
// convert bfp16 to fp32
template <>
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
...
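The new uint8x2_t / uint8x4_t aliases and the type_convert<half_t, uint8_t> specialization give the uint8 path the same vector and conversion plumbing the other element types already have; the specialization itself is nothing more than a static_cast. A standalone sketch of the same idea, with float used as a placeholder for ck::half_t so it compiles with any host compiler (that substitution is an assumption made only for illustration):

#include <cstdint>
#include <cstdio>

// Placeholder for ck::half_t; fp16 support is compiler-dependent on the host.
using half_stand_in = float;

// Mirrors the added specialization: converting a byte to fp16 is a plain cast.
template <typename Y, typename X>
constexpr Y type_convert_sketch(X x)
{
    return static_cast<Y>(x);
}

int main()
{
    const std::uint8_t quantized = 200;
    const half_stand_in h        = type_convert_sketch<half_stand_in>(quantized);
    std::printf("%g\n", static_cast<double>(h)); // 200
    return 0;
}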
3rdparty/composable_kernel/include/ck/utility/inner_product.hpp
View file @ 068fb458
...
@@ -9,6 +9,31 @@ namespace ck {
template <typename TA, typename TB, typename TC>
__device__ void inner_product(const TA& a, const TB& b, TC& c);

+template <>
+__device__ void
+inner_product<half2_t, uint8x2_t, float>(const half2_t& a, const uint8x2_t& b, float& c)
+{
+    const vector_type<half_t, 2> a_vector{a};
+    const vector_type<uint8_t, 2> b_vector{b};
+    const vector_type<half_t, 2> b_fp16_vector;
+
+    static constexpr uint32_t mask_for_elt_01     = 0x05020500;
+    static constexpr uint32_t mask_for_elt_23     = 0x05030501;
+    static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
+    asm volatile("v_perm_b32 %0,%1,%2,%3;\n"
+                 : "=v"(((uint32_t*)&b_fp16_vector.data_)[0])
+                 : "v"(start_byte_for_fp16), "v"(((uint32_t*)&b_vector.data_)[0]), "v"(mask_for_elt_01));
+    // asm volatile("v_perm_b32 %0,%1,%2,%3;\n" : "=v"(((uint32_t*)&b_fp16_vector.data_)[1]) : "v"(start_byte_for_fp16), "v"(((uint32_t*)&b_vector.data_)[0]), "v"(mask_for_elt_23));
+
+    static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x6480;
+    asm volatile("v_sub_f16x2 %0, %1, %2;\n"
+                 : "=v"(((uint32_t*)&b_fp16_vector.data_)[0])
+                 : "v"(((uint32_t*)&b_fp16_vector.data_)[0]), "v"(I8s_TO_F16s_MAGIC_NUM));
+    // asm volatile("v_sub_f16x2 %0, %1, %2;\n" : "=v"(((uint32_t*)&b_fp16_vector.data_)[1]) : "v"(((uint32_t*)&b_fp16_vector.data_)[1]), "v"(I8s_TO_F16s_MAGIC_NUM));
+
+    static_for<0, 2, 1>{}([&](auto i) {
+        c += type_convert<int32_t>(a_vector.AsType<half_t>()[i]) *
+             type_convert<int32_t>(b_fp16_vector.AsType<half_t>()[i]);
+    });
+}
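The new specialization converts two uint8 values to fp16 without integer-to-float instructions: v_perm_b32 with start_byte_for_fp16 = 0x64646464 splices each byte u into the fp16 bit pattern 0x64'u', whose value is exactly 1024 + u, and v_sub_f16x2 then subtracts I8s_TO_F16s_MAGIC_NUM = 0x6480, the fp16 encoding of 1152 = 1024 + 128. A byte that stores an int8 weight in excess-128 (biased) form therefore comes out as the original signed value; how B is packed upstream is outside this diff. Every intermediate is an integer below 2048 and hence exact in fp16. A host-side sketch that checks this bit-level identity in plain C++, with no GPU intrinsics and the narrow fp16 decode done by hand:

#include <cassert>
#include <cstdint>
#include <cstdio>

// Decode an fp16 bit pattern of the form 0x64XX (biased exponent 25, top two
// mantissa bits zero): its value is exactly 1024 + XX. Only this narrow range
// is needed here, so no general half-float library is required.
static int fp16_0x64xx_value(std::uint16_t bits)
{
    assert((bits & 0xFF00u) == 0x6400u);
    return 1024 + static_cast<int>(bits & 0x00FFu);
}

int main()
{
    const int magic = fp16_0x64xx_value(0x6480u); // 1152 = 1024 + 128

    for(int u = 0; u < 256; ++u)
    {
        // Step 1: v_perm_b32 splices byte u next to the 0x64 start byte.
        const std::uint16_t packed = static_cast<std::uint16_t>(0x6400u | u);

        // Step 2: v_sub_f16x2 subtracts the magic constant; all values are
        // small integers, so fp16 arithmetic is exact here.
        const int result = fp16_0x64xx_value(packed) - magic; // u - 128

        // A byte in excess-128 form (u = i + 128, i.e. i = int8_t(u ^ 0x80))
        // decodes back to the original signed value i.
        const int expected = static_cast<std::int8_t>(u ^ 0x80);
        assert(result == expected);
    }
    std::printf("magic-number uint8 -> fp16 identity holds for all 256 bytes\n");
    return 0;
}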
template <>
__device__ void inner_product<float, float, float>(const float& a, const float& b, float& c)
{
...
@@ -71,12 +96,6 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
              c);
}

-template <>
-__device__ void inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
-{
-    c += static_cast<float>(a * b);
-}
-
template <>
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{
...
3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp
View file @ 068fb458
...
@@ -15,6 +15,7 @@ namespace device {
namespace instance {

using F16 = ck::half_t;
+using I8  = int8_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
...
@@ -34,13 +35,13 @@ using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple<
    // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
    // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
    // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-   DeviceGemmDl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
+   DeviceGemmDl<F16, I8, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
    // clang-format on
    >;

void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<
-       DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+       DeviceGemm<Col, Row, Row, F16, I8, F16, PassThrough, PassThrough, PassThrough>>>&
        instances)
{
    add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_kn_mn_instances{});
...