Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ac876c6f
Commit
ac876c6f
authored
Jul 02, 2022
by
Chao Liu
Browse files
gemm bilinear
parent
8ec065fd
Changes
56
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
214 additions
and
1850 deletions
+214
-1850
example/02_gemm_alpha_beta/CMakeLists.txt
example/02_gemm_alpha_beta/CMakeLists.txt
+0
-1
example/02_gemm_bilinear/CMakeLists.txt
example/02_gemm_bilinear/CMakeLists.txt
+1
-0
example/02_gemm_bilinear/README.md
example/02_gemm_bilinear/README.md
+6
-4
example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp
example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp
+0
-0
example/CMakeLists.txt
example/CMakeLists.txt
+1
-1
include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
...de/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
+17
-16
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
+10
-10
include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp
include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp
+0
-45
include/ck/tensor_operation/gpu/device/device_gemm_bias_activation.hpp
...nsor_operation/gpu/device/device_gemm_bias_activation.hpp
+0
-45
include/ck/tensor_operation/gpu/device/device_gemm_bias_activation_add.hpp
..._operation/gpu/device/device_gemm_bias_activation_add.hpp
+0
-50
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp
...peration/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp
+0
-513
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp
.../gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp
+0
-520
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp
.../device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp
+0
-580
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+21
-3
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+3
-1
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp
...y/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp
+10
-8
library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp
...k/library/tensor_operation_instance/gpu/gemm_bilinear.hpp
+97
-0
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+8
-13
library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
..._fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
+20
-20
library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
..._fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
+20
-20
No files found.
example/02_gemm_alpha_beta/CMakeLists.txt
deleted
100644 → 0
View file @
8ec065fd
add_example_executable
(
example_gemm_alpha_beta_xdl_fp16 gemm_alpha_beta_xdl_fp16.cpp
)
example/02_gemm_bilinear/CMakeLists.txt
0 → 100644
View file @
ac876c6f
add_example_executable
(
example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp
)
example/02_gemm_
alpha_beta
/README.md
→
example/02_gemm_
bilinear
/README.md
View file @
ac876c6f
# Instructions for ```example_gemm_
xdl_alpha_beta
```
# Instructions for ```example_gemm_
bilinear_xdl_fp16
```
## Run ```example_gemm_
xdl_alpha_beta
```
## Run ```example_gemm_
bilinear_xdl_fp16
```
```
bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1)
./bin/example_gemm_xdl_alpha_beta 1 1 1 0.5 0.5
#arg3: time kernel (0=no, 1=yes)
#arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE
#arg11 to 12: alpha, beta
./bin/example_gemm_bilinear_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096 4096 0.5 0.5
```
Result (MI100 @ 1502Mhz, 184.6TFlops peak FP16)
```
...
...
example/02_gemm_
alpha_beta/gemm_alpha_beta
_xdl_fp16.cpp
→
example/02_gemm_
bilinear/gemm_bilinear
_xdl_fp16.cpp
View file @
ac876c6f
File moved
example/CMakeLists.txt
View file @
ac876c6f
...
...
@@ -22,7 +22,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endfunction
(
add_example_executable_no_testing EXAMPLE_NAME
)
add_subdirectory
(
01_gemm
)
add_subdirectory
(
02_gemm_
alpha_beta
)
add_subdirectory
(
02_gemm_
bilinear
)
add_subdirectory
(
03_gemm_bias_relu
)
add_subdirectory
(
04_gemm_add_add_fastgelu
)
add_subdirectory
(
06_conv2d_fwd_bias_relu
)
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
View file @
ac876c6f
...
...
@@ -23,22 +23,23 @@ template <typename ALayout,
typename
CElementwiseOperation
>
struct
DeviceBatchedGemm
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
ck
::
index_t
BatchStrideA
,
ck
::
index_t
BatchStrideB
,
ck
::
index_t
BatchStrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
Batch
)
=
0
;
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
ck
::
index_t
BatchStrideA
,
ck
::
index_t
BatchStrideB
,
ck
::
index_t
BatchStrideC
,
ck
::
index_t
Batch
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
View file @
ac876c6f
...
...
@@ -344,12 +344,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideC
,
index_t
Batch
,
index_t
M01
,
index_t
N01
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
Batch
)
CElementwiseOperation
c_element_op
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
...
...
@@ -546,10 +546,10 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideC
,
index_t
Batch
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
Batch
)
CElementwiseOperation
c_element_op
)
{
return
Argument
{
p_a
,
p_b
,
...
...
@@ -563,12 +563,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
BatchStrideA
,
BatchStrideB
,
BatchStrideC
,
Batch
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
,
Batch
};
c_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -586,10 +586,10 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideC
,
index_t
Batch
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
Batch
)
override
CElementwiseOperation
c_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
...
...
@@ -603,12 +603,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
BatchStrideA
,
BatchStrideB
,
BatchStrideC
,
Batch
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
,
Batch
);
c_element_op
);
}
// polymorphic
...
...
include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp
deleted
100644 → 0
View file @
8ec065fd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGemmBias
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_bias
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGemmBiasPtr
=
std
::
unique_ptr
<
DeviceGemmBias
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_bias_activation.hpp
deleted
100644 → 0
View file @
8ec065fd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGemmBiasActivation
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
const
void
*
p_c0
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
KBatch
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGemmBiasActivationPtr
=
std
::
unique_ptr
<
DeviceGemmBiasActivation
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_bias_activation_add.hpp
deleted
100644 → 0
View file @
8ec065fd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef DEVICE_GEMM_BIAS_ACTIVATION_ADD_HPP
#define DEVICE_GEMM_BIAS_ACTIVATION_ADD_HPP
#include <iostream>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGemmBiasActivationAdd
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
const
void
*
p_c0
,
const
void
*
p_c1
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
ck
::
index_t
StrideC1
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
KBatch
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGemmBiasActivationAddPtr
=
std
::
unique_ptr
<
DeviceGemmBiasActivationAdd
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
#endif
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp
deleted
100644 → 0
View file @
8ec065fd
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp
deleted
100644 → 0
View file @
8ec065fd
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp
deleted
100644 → 0
View file @
8ec065fd
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
ac876c6f
...
...
@@ -94,10 +94,11 @@ struct Subtract
}
};
struct
AlphaBetaAdd
struct
Bilinear
{
AlphaBetaAdd
(
float
alpha
,
float
beta
)
:
alpha_
(
alpha
),
beta_
(
beta
){};
Bilinear
(
float
alpha
,
float
beta
)
:
alpha_
(
alpha
),
beta_
(
beta
){};
#if 0
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
...
...
@@ -115,7 +116,6 @@ struct AlphaBetaAdd
y = type_convert<double>(alpha_) * x0 + type_convert<double>(beta_) * x1;
};
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
...
...
@@ -123,6 +123,24 @@ struct AlphaBetaAdd
y = type_convert<half_t>(alpha_ * type_convert<float>(x0) +
beta_ * type_convert<float>(x1));
};
#else
template
<
typename
Y
,
typename
X0
,
typename
X1
>
__host__
__device__
constexpr
void
operator
()(
Y
&
,
const
X0
&
,
const
X1
&
)
const
;
template
<
>
__host__
__device__
constexpr
void
operator
()
<
float
,
float
,
float
>
(
float
&
y
,
const
float
&
x0
,
const
float
&
x1
)
const
{
y
=
alpha_
*
x0
+
beta_
*
x1
;
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
half_t
,
float
,
half_t
>
(
half_t
&
y
,
const
float
&
x0
,
const
half_t
&
x1
)
const
{
y
=
type_convert
<
half_t
>
(
alpha_
*
x0
+
beta_
*
ck
::
type_convert
<
float
>
(
x1
));
};
#endif
float
alpha_
;
float
beta_
;
...
...
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
ac876c6f
...
...
@@ -16,12 +16,14 @@ using F32 = float;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F16_F16
=
ck
::
Tuple
<
F16
,
F16
>
;
using
F16_TUPLE
=
ck
::
Tuple
<
F16
>
;
using
F16_F16_TUPLE
=
ck
::
Tuple
<
F16
,
F16
>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Bilinear
=
ck
::
tensor_operation
::
element_wise
::
Bilinear
;
using
AddAddFastGelu
=
ck
::
tensor_operation
::
element_wise
::
AddAddFastGelu
;
template
<
typename
DeviceOp
>
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp
View file @
ac876c6f
...
...
@@ -25,7 +25,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instanc
Row
,
F16
,
F16
,
F16_F16
,
F16_F16
_TUPLE
,
F16
,
PassThrough
,
PassThrough
,
...
...
@@ -37,7 +37,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instanc
Row
,
F16
,
F16
,
F16_F16
,
F16_F16
_TUPLE
,
F16
,
PassThrough
,
PassThrough
,
...
...
@@ -49,7 +49,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instanc
Row
,
F16
,
F16
,
F16_F16
,
F16_F16
_TUPLE
,
F16
,
PassThrough
,
PassThrough
,
...
...
@@ -61,7 +61,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instanc
Row
,
F16
,
F16
,
F16_F16
,
F16_F16
_TUPLE
,
F16
,
PassThrough
,
PassThrough
,
...
...
@@ -73,7 +73,8 @@ template <typename ALayout,
typename
DELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
D0DataType
,
typename
D1DataType
,
typename
EDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
<
ALayout
,
...
...
@@ -81,7 +82,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
DELayout
,
ADataType
,
BDataType
,
Ds
DataType
,
ck
::
Tuple
<
D0DataType
,
D1
DataType
>
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -92,7 +93,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
DELayout
,
ADataType
,
BDataType
,
Ds
DataType
,
ck
::
Tuple
<
D0DataType
,
D1
DataType
>
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -103,7 +104,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
DsDataType
,
Tuple
<
half_t
,
half_t
>>
&&
is_same_v
<
EDataType
,
half_t
>
)
is_same_v
<
D0DataType
,
half_t
>
&&
is_same_v
<
D1DataType
,
half_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
DELayout
,
Row
>
)
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp
0 → 100644
View file @
ac876c6f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16_TUPLE
,
F16
,
PassThrough
,
PassThrough
,
Bilinear
>>>&
instances
);
// GEMM + Bilinear
template
<
typename
ALayout
,
typename
BLayout
,
typename
DELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DDataType
,
typename
EDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
DELayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<
DDataType
>
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
Bilinear
>>
{
using
DeviceOp
=
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
DELayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<
DDataType
>
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
Bilinear
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
DDataType
,
half_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
DELayout
,
Row
>
)
{}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
DELayout
,
Row
>
)
{
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
DELayout
,
Row
>
)
{
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
DELayout
,
Row
>
)
{
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
ac876c6f
...
...
@@ -5,15 +5,12 @@ function(add_instance_library INSTANCE_NAME)
set_target_properties
(
${
INSTANCE_NAME
}
PROPERTIES POSITION_INDEPENDENT_CODE ON
)
endfunction
(
add_instance_library INSTANCE_NAME
)
add_subdirectory
(
elementwise
)
add_subdirectory
(
gemm
)
add_subdirectory
(
gemm_splitk
)
add_subdirectory
(
gemm_bias2d
)
add_subdirectory
(
gemm_bias_relu
)
add_subdirectory
(
gemm_bias_relu_add
)
add_subdirectory
(
gemm_bilinear
)
add_subdirectory
(
gemm_add_add_fastgelu
)
add_subdirectory
(
gemm_reduce
)
add_subdirectory
(
gemm_bias_add_reduce
)
add_subdirectory
(
gemm_add_add_fastgelu
)
add_subdirectory
(
batched_gemm
)
add_subdirectory
(
batched_gemm_reduce
)
add_subdirectory
(
grouped_gemm
)
...
...
@@ -25,17 +22,16 @@ add_subdirectory(conv2d_fwd_bias_relu_add)
add_subdirectory
(
conv2d_bwd_data
)
add_subdirectory
(
convnd_bwd_data
)
add_subdirectory
(
conv2d_bwd_weight
)
add_subdirectory
(
normalization
)
add_subdirectory
(
reduce
)
add_subdirectory
(
normalization
)
add_subdirectory
(
elementwise
)
add_library
(
device_operations STATIC
$<TARGET_OBJECTS:device_gemm_instance>
$<TARGET_OBJECTS:device_gemm_splitk_instance>
$<TARGET_OBJECTS:device_gemm_bias_relu_instance>
$<TARGET_OBJECTS:device_gemm_bias_relu_add_instance>
$<TARGET_OBJECTS:device_gemm_bias_add_reduce_instance>
$<TARGET_OBJECTS:device_gemm_bias2d_instance>
$<TARGET_OBJECTS:device_gemm_bilinear_instance>
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
$<TARGET_OBJECTS:device_gemm_bias_add_reduce_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
$<TARGET_OBJECTS:device_grouped_gemm_instance>
...
...
@@ -47,9 +43,9 @@ add_library(device_operations STATIC
$<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
$<TARGET_OBJECTS:device_convnd_bwd_data_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
$<TARGET_OBJECTS:device_elementwise_instance>
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
$<TARGET_OBJECTS:device_reduce_instance>
$<TARGET_OBJECTS:device_normalization_instance>
$<TARGET_OBJECTS:device_elementwise_instance>
)
add_library
(
composablekernels::device_operations ALIAS device_operations
)
...
...
@@ -81,7 +77,6 @@ target_include_directories(device_operations PUBLIC
#once new arches are enabled make this an option on the main cmake file
# and pass down here to be exported
target_compile_options
(
device_operations PRIVATE
--offload-arch=gfx908
--offload-arch=gfx90a
...
...
library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
View file @
ac876c6f
...
...
@@ -14,9 +14,9 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
F16_F16
=
ck
::
Tuple
<
F16
,
F16
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
F16_F16
_Tuple
=
ck
::
Tuple
<
F16
,
F16
>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -38,22 +38,22 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
//##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
2
,
2
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
2
,
2
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
...
...
@@ -63,7 +63,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instanc
Row
,
F16
,
F16
,
F16_F16
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
...
...
library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
View file @
ac876c6f
...
...
@@ -14,9 +14,9 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
F16_F16
=
ck
::
Tuple
<
F16
,
F16
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
F16_F16
_Tuple
=
ck
::
Tuple
<
F16
,
F16
>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -38,22 +38,22 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
//##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
2
,
8
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
2
,
8
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
...
...
@@ -63,7 +63,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instanc
Row
,
F16
,
F16
,
F16_F16
,
F16_F16
_Tuple
,
F16
,
PassThrough
,
PassThrough
,
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment