Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
162d0305
Commit
162d0305
authored
Apr 22, 2024
by
root
Browse files
add client example
parent
d8ab41d5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
16 deletions
+13
-16
client_example/30_gemm_bf16Aint8B_add_fastgelu/CMakeLists.txt
...nt_example/30_gemm_bf16Aint8B_add_fastgelu/CMakeLists.txt
+3
-0
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_add_fastgelu.hpp
...sor_operation_instance/gpu/gemm_multiply_add_fastgelu.hpp
+2
-2
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i8_bf16_multi_d/device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn.hpp
...vice_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn.hpp
+8
-14
No files found.
client_example/30_gemm_bf16Aint8B_add_fastgelu/CMakeLists.txt
View file @
162d0305
...
...
@@ -2,6 +2,9 @@ if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf
add_executable
(
client_gemm_bias_fastgelu_bf16_i8_bf16 gemm_bias_fastgelu_xdl_bf16_i8.cpp
)
target_link_libraries
(
client_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations
)
add_executable
(
client_gemm_multiply_add_fastgelu_xdl_bf16_i8 gemm_multiply_add_fastgelu_xdl_bf16_i8.cpp
)
target_link_libraries
(
client_gemm_multiply_add_fastgelu_xdl_bf16_i8 PRIVATE composable_kernel::device_gemm_operations
)
add_executable
(
client_gemm_bias_bf16_i8_bf16 gemm_bias_xdl_bf16_i8.cpp
)
target_link_libraries
(
client_gemm_bias_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations
)
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_add_fastgelu.hpp
View file @
162d0305
...
...
@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
bhalf_t
>
&&
is_same_v
<
BDataType
,
bhalf
_t
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
bhalf_t
>
&&
is_same_v
<
BDataType
,
int8
_t
>
&&
is_same_v
<
D0DataType
,
bhalf_t
>
&&
is_same_v
<
D1DataType
,
bhalf_t
>
&&
is_same_v
<
EDataType
,
bhalf_t
>
)
{
...
...
@@ -77,7 +77,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
D1Layout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_
add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn
_instances
(
add_device_gemm_
xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_multiply_add_fastgelu_mnkpadding
_instances
(
op_ptrs
);
}
}
...
...
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i8_bf16_multi_d/device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn.hpp
View file @
162d0305
...
...
@@ -14,9 +14,9 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
using
I8
=
int8_t
;
using
I8
=
int8_t
;
using
BF16
=
bhalf_t
;
using
F32
=
float
;
using
F32
=
float
;
using
Row
=
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -24,7 +24,7 @@ using Col = tensor_layout::gemm::ColumnMajor;
template
<
index_t
...
Is
>
using
S
=
Sequence
<
Is
...
>
;
using
PassThrough
=
element_wise
::
PassThrough
;
using
PassThrough
=
element_wise
::
PassThrough
;
using
MultiplyAddFastGelu
=
element_wise
::
MultiplyAddFastGelu
;
static
constexpr
auto
GemmDefault
=
GemmSpecialization
::
Default
;
...
...
@@ -37,11 +37,7 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
using
DsLayout
=
ck
::
Tuple
<
Row
,
Row
>
;
template
<
typename
DsDType
,
typename
CElementwiseOp
,
GemmSpecialization
GemmSpec
>
template
<
typename
DsDType
,
typename
CElementwiseOp
,
GemmSpecialization
GemmSpec
>
using
device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_comp_instances
=
std
::
tuple
<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
...
...
@@ -60,12 +56,10 @@ using device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_comp_instances = s
// clang-format on
>
;
template
<
typename
DsDType
,
typename
CElementwiseOp
,
GemmSpecialization
GemmSpec
,
BlockGemmPipelineScheduler
BlkGemmPipeSched
>
template
<
typename
DsDType
,
typename
CElementwiseOp
,
GemmSpecialization
GemmSpec
,
BlockGemmPipelineScheduler
BlkGemmPipeSched
>
using
device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_mem_instances
=
std
::
tuple
<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment