Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e573a2a0
Unverified
Commit
e573a2a0
authored
Jun 30, 2022
by
Chao Liu
Committed by
GitHub
Jun 30, 2022
Browse files
Merge branch 'develop' into batched_gemm_g_stride_fix
parents
6adf3591
0dcb3496
Changes
258
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
163 additions
and
199 deletions
+163
-199
profiler/include/profile_batched_gemm_reduce_impl.hpp
profiler/include/profile_batched_gemm_reduce_impl.hpp
+7
-8
profiler/include/profile_conv_bwd_weight_impl.hpp
profiler/include/profile_conv_bwd_weight_impl.hpp
+4
-4
profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp
profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp
+3
-3
profiler/include/profile_conv_fwd_bias_relu_impl.hpp
profiler/include/profile_conv_fwd_bias_relu_impl.hpp
+3
-3
profiler/include/profile_convnd_bwd_data_impl.hpp
profiler/include/profile_convnd_bwd_data_impl.hpp
+15
-16
profiler/include/profile_gemm_add_add_fastgelu_impl.hpp
profiler/include/profile_gemm_add_add_fastgelu_impl.hpp
+21
-22
profiler/include/profile_gemm_bias_2d_impl.hpp
profiler/include/profile_gemm_bias_2d_impl.hpp
+11
-12
profiler/include/profile_gemm_bias_add_reduce_impl.hpp
profiler/include/profile_gemm_bias_add_reduce_impl.hpp
+7
-8
profiler/include/profile_gemm_bias_relu_add_impl.hpp
profiler/include/profile_gemm_bias_relu_add_impl.hpp
+7
-8
profiler/include/profile_gemm_bias_relu_impl.hpp
profiler/include/profile_gemm_bias_relu_impl.hpp
+7
-8
profiler/include/profile_gemm_impl.hpp
profiler/include/profile_gemm_impl.hpp
+18
-11
profiler/include/profile_gemm_reduce_impl.hpp
profiler/include/profile_gemm_reduce_impl.hpp
+7
-8
profiler/include/profile_gemm_splitk_impl.hpp
profiler/include/profile_gemm_splitk_impl.hpp
+16
-15
profiler/include/profile_grouped_gemm_impl.hpp
profiler/include/profile_grouped_gemm_impl.hpp
+7
-9
profiler/include/profile_normalization_impl.hpp
profiler/include/profile_normalization_impl.hpp
+10
-10
profiler/include/profile_reduce_impl.hpp
profiler/include/profile_reduce_impl.hpp
+4
-4
profiler/src/profile_gemm_add_add_fastgelu.cpp
profiler/src/profile_gemm_add_add_fastgelu.cpp
+10
-16
script/docker-rocm4.1.sh
script/docker-rocm4.1.sh
+0
-14
script/docker-rocm4.3.1.sh
script/docker-rocm4.3.1.sh
+0
-14
test/conv2d_bwd_data/conv2d_bwd_data.cpp
test/conv2d_bwd_data/conv2d_bwd_data.cpp
+6
-6
No files found.
profiler/include/profile_batched_gemm_reduce_impl.hpp
View file @
e573a2a0
...
...
@@ -19,7 +19,7 @@
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_
instance
{
namespace
instance
{
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
...
...
@@ -44,7 +44,7 @@ void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn
void
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmReduceNoOpPtr
>&
);
}
// namespace
device_gemm_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -208,8 +208,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
b_device_buf
.
ToDevice
(
b_g_k_n
.
mData
.
data
());
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
DeviceGemmReduceNoOpPtr
>
gemm_ptrs
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
instance
::
DeviceGemmReduceNoOpPtr
>
gemm_ptrs
;
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
...
...
@@ -218,7 +217,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances
(
gemm_ptrs
);
}
...
...
@@ -226,7 +225,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances
(
gemm_ptrs
);
}
...
...
@@ -234,7 +233,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances
(
gemm_ptrs
);
}
...
...
@@ -242,7 +241,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances
(
gemm_ptrs
);
}
...
...
profiler/include/profile_conv_bwd_weight_impl.hpp
View file @
e573a2a0
...
...
@@ -18,7 +18,7 @@
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv2d_bwd_weight_
instance
{
namespace
instance
{
using
DeviceConvBwdWeightNoOpPtr
=
DeviceConvBwdWeightPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -31,7 +31,7 @@ void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
void
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances
(
std
::
vector
<
DeviceConvBwdWeightNoOpPtr
>&
);
}
// namespace
device_conv2d_bwd_weight_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -165,14 +165,14 @@ bool profile_conv_bwd_weight_impl(int do_verification,
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
float
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
float
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_weight_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances
(
conv_ptrs
);
}
else
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
ck
::
half_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
ck
::
half_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
ck
::
half_t
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_weight_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances
(
conv_ptrs
);
}
...
...
profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp
View file @
e573a2a0
...
...
@@ -17,7 +17,7 @@
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv2d_fwd_bias_activation_add_
instance
{
namespace
instance
{
using
DeviceConvFwdBiasReluAddPtr
=
DeviceConvFwdBiasActivationAddPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluAddPtr =
void
add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
DeviceConvFwdBiasReluAddPtr
>&
);
}
// namespace
device_conv2d_fwd_bias_activation_add_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -179,7 +179,7 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
ck
::
half_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
ck
::
half_t
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_bias_activation_add_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances
(
op_ptrs
);
}
...
...
profiler/include/profile_conv_fwd_bias_relu_impl.hpp
View file @
e573a2a0
...
...
@@ -17,7 +17,7 @@
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv2d_fwd_bias_activation_
instance
{
namespace
instance
{
using
DeviceConvFwdBiasReluPtr
=
DeviceConvFwdBiasActivationPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluPtr =
void
add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
DeviceConvFwdBiasReluPtr
>&
);
}
// namespace
device_conv2d_fwd_bias_activation_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -169,7 +169,7 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
ck
::
half_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
ck
::
half_t
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_bias_activation_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances
(
op_ptrs
);
}
...
...
profiler/include/profile_convnd_bwd_data_impl.hpp
View file @
e573a2a0
...
...
@@ -22,7 +22,7 @@ using INT8 = int8_t;
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv2d_bwd_data_
instance
{
namespace
instance
{
using
DeviceConvBwdDataNoOpPtr
=
DeviceConvBwdDataPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -54,15 +54,14 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
std
::
vector
<
DeviceConvBwdDataNoOpPtr
>&
);
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances
(
std
::
vector
<
DeviceConvBwdDataNoOpPtr
>&
);
}
// namespace
device_conv2d_bwd_data_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
namespace
ck
{
namespace
profiler
{
using
DeviceConvBwdDataNoOpPtr
=
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_instance
::
DeviceConvBwdDataNoOpPtr
;
using
DeviceConvBwdDataNoOpPtr
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceConvBwdDataNoOpPtr
;
template
<
typename
InLayout
>
HostTensorDescriptor
get_input_host_tensor_descriptor
(
const
std
::
vector
<
std
::
size_t
>&
dims
,
...
...
@@ -144,15 +143,15 @@ void get_device_conv_bwd_data_op_ptr(
switch
(
num_dim_spatial
)
{
case
1
:
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances
(
conv_ptrs
);
break
;
case
2
:
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances
(
conv_ptrs
);
break
;
case
3
:
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances
(
conv_ptrs
);
break
;
default:
break
;
...
...
@@ -165,15 +164,15 @@ void get_device_conv_bwd_data_op_ptr(
switch
(
num_dim_spatial
)
{
case
1
:
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances
(
conv_ptrs
);
break
;
case
2
:
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances
(
conv_ptrs
);
break
;
case
3
:
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances
(
conv_ptrs
);
break
;
default:
break
;
...
...
@@ -186,15 +185,15 @@ void get_device_conv_bwd_data_op_ptr(
switch
(
num_dim_spatial
)
{
case
1
:
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances
(
conv_ptrs
);
break
;
case
2
:
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances
(
conv_ptrs
);
break
;
case
3
:
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances
(
conv_ptrs
);
break
;
default:
break
;
...
...
@@ -207,15 +206,15 @@ void get_device_conv_bwd_data_op_ptr(
switch
(
num_dim_spatial
)
{
case
1
:
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances
(
conv_ptrs
);
break
;
case
2
:
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances
(
conv_ptrs
);
break
;
case
3
:
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances
(
conv_ptrs
);
break
;
default:
break
;
...
...
profiler/include/profile_gemm_add_add_fastgelu_impl.hpp
View file @
e573a2a0
...
...
@@ -10,13 +10,12 @@
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/
device_
gemm_add_add_fastgelu
_instance
.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/host_tensor/host_conv.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace
ck
{
...
...
@@ -30,9 +29,7 @@ template <typename ADataType,
typename
EDataType
,
typename
ALayout
,
typename
BLayout
,
typename
D0Layout
,
typename
D1Layout
,
typename
ELayout
>
typename
DELayout
>
// assume Ds and E have same layout
bool
profile_gemm_add_add_fastgelu_impl
(
int
do_verification
,
int
init_method
,
bool
/*do_log*/
,
...
...
@@ -62,10 +59,10 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
D0DataType
>
d0_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD0
,
D
0
Layout
{}));
Tensor
<
D1DataType
>
d1_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD1
,
D
1
Layout
{}));
Tensor
<
EDataType
>
e_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
Tensor
<
EDataType
>
e_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
Tensor
<
D0DataType
>
d0_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD0
,
D
E
Layout
{}));
Tensor
<
D1DataType
>
d1_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD1
,
D
E
Layout
{}));
Tensor
<
EDataType
>
e_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
D
ELayout
{}));
Tensor
<
EDataType
>
e_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
D
ELayout
{}));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
...
...
@@ -100,19 +97,21 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
const
auto
b_element_op
=
BElementOp
{};
const
auto
cde_element_op
=
CDEElementOp
{};
// add device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
get_device_gemm_add_add_fastgelu_instances
<
ADataType
,
BDataType
,
AccDataType
,
D0DataType
,
D1DataType
,
EDataType
,
ALayout
,
BLayout
,
D0Layout
,
D1Layout
,
ELayout
>
();
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
DELayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<
D0DataType
,
D1DataType
>
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
AddAddFastGelu
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
...
...
profiler/include/profile_gemm_bias_2d_impl.hpp
View file @
e573a2a0
...
...
@@ -17,7 +17,7 @@
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_
instance
{
namespace
instance
{
using
DeviceGemmAlphaBetaPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmBiasPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -48,7 +48,7 @@ void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(
void
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmAlphaBetaPtr
>&
);
}
// namespace
device_gemm_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -159,8 +159,7 @@ void profile_gemm_bias_2d_impl(int do_verification,
c_device_buf
.
ToDevice
(
c_m_n_device_result
.
mData
.
data
());
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
DeviceGemmAlphaBetaPtr
>
gemm_ptrs
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
instance
::
DeviceGemmAlphaBetaPtr
>
gemm_ptrs
;
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
...
...
@@ -169,28 +168,28 @@ void profile_gemm_bias_2d_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances
(
gemm_ptrs
);
}
}
...
...
@@ -201,28 +200,28 @@ void profile_gemm_bias_2d_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
}
}
...
...
profiler/include/profile_gemm_bias_add_reduce_impl.hpp
View file @
e573a2a0
...
...
@@ -19,7 +19,7 @@
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_
instance
{
namespace
instance
{
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
...
...
@@ -45,7 +45,7 @@ void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f
void
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmBiasAddReduceNoOpPtr
>&
);
}
// namespace
device_gemm_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -236,8 +236,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
d0_device_buf
.
ToDevice
(
d0_m_n
.
mData
.
data
());
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
DeviceGemmBiasAddReduceNoOpPtr
>
gemm_ptrs
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
instance
::
DeviceGemmBiasAddReduceNoOpPtr
>
gemm_ptrs
;
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
...
...
@@ -246,7 +245,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
}
...
...
@@ -254,7 +253,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
}
...
...
@@ -262,7 +261,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
}
...
...
@@ -270,7 +269,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
}
...
...
profiler/include/profile_gemm_bias_relu_add_impl.hpp
View file @
e573a2a0
...
...
@@ -18,7 +18,7 @@
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_
instance
{
namespace
instance
{
using
DeviceGemmBiasReluAddPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmBiasActivationAddPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -34,7 +34,7 @@ void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances(
void
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmBiasReluAddPtr
>&
);
}
// namespace
device_gemm_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -158,8 +158,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
c1_m_n_device_buf
.
ToDevice
(
c1_m_n
.
mData
.
data
());
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
DeviceGemmBiasReluAddPtr
>
gemm_ptrs
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
instance
::
DeviceGemmBiasReluAddPtr
>
gemm_ptrs
;
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
...
...
@@ -168,7 +167,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances
(
gemm_ptrs
);
}
...
...
@@ -176,7 +175,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances
(
gemm_ptrs
);
}
...
...
@@ -184,7 +183,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances
(
gemm_ptrs
);
}
...
...
@@ -192,7 +191,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances
(
gemm_ptrs
);
}
...
...
profiler/include/profile_gemm_bias_relu_impl.hpp
View file @
e573a2a0
...
...
@@ -18,7 +18,7 @@
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_
instance
{
namespace
instance
{
using
DeviceGemmBiasReluPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmBiasActivationPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -34,7 +34,7 @@ void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(
void
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmBiasReluPtr
>&
);
}
// namespace
device_gemm_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -144,8 +144,7 @@ void profile_gemm_bias_relu_impl(int do_verification,
c0_n_device_buf
.
ToDevice
(
c0_n
.
mData
.
data
());
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
DeviceGemmBiasReluPtr
>
gemm_ptrs
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
instance
::
DeviceGemmBiasReluPtr
>
gemm_ptrs
;
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
...
...
@@ -154,28 +153,28 @@ void profile_gemm_bias_relu_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances
(
gemm_ptrs
);
}
}
...
...
profiler/include/profile_gemm_impl.hpp
View file @
e573a2a0
...
...
@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/
device_gemm_instance
.hpp"
#include "ck/library/tensor_operation_instance/gpu/
gemm
.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
...
...
@@ -94,14 +94,21 @@ int profile_gemm_impl(int do_verification,
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
c_device_buf
.
ToDevice
(
c_m_n_device_result
.
mData
.
data
());
// add device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
get_device_gemm_instances
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
();
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
if
(
op_ptrs
.
size
()
<=
0
)
{
throw
std
::
runtime_error
(
"wrong! no device GEMM instance found"
);
}
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
// Run reference GEMM
if
(
do_verification
)
...
...
@@ -141,9 +148,9 @@ int profile_gemm_impl(int do_verification,
StrideA
,
StrideB
,
StrideC
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}
,
c
k
::
tensor_operation
::
element_wise
::
PassThrough
{}
);
a_element_op
,
b_element_op
,
c
_element_op
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
...
...
profiler/include/profile_gemm_reduce_impl.hpp
View file @
e573a2a0
...
...
@@ -19,7 +19,7 @@
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_
instance
{
namespace
instance
{
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
...
...
@@ -45,7 +45,7 @@ void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
void
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmReduceNoOpPtr
>&
);
}
// namespace
device_gemm_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -204,8 +204,7 @@ bool profile_gemm_reduce_impl(int do_verification,
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
DeviceGemmReduceNoOpPtr
>
gemm_ptrs
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
instance
::
DeviceGemmReduceNoOpPtr
>
gemm_ptrs
;
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
...
...
@@ -214,7 +213,7 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
}
...
...
@@ -222,7 +221,7 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
}
...
...
@@ -230,7 +229,7 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
}
...
...
@@ -238,7 +237,7 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
}
...
...
profiler/include/profile_gemm_splitk_impl.hpp
View file @
e573a2a0
...
...
@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/
device_
gemm_splitk
_instance
.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
...
...
@@ -95,20 +95,21 @@ bool profile_gemm_splitk_impl(int do_verification,
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
c_device_buf
.
ToDevice
(
c_m_n_device_result
.
mData
.
data
());
// add device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
get_device_gemm_splitk_instances
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
();
if
(
op_ptrs
.
size
()
<=
0
)
{
throw
std
::
runtime_error
(
"wrong! no device operation instance found"
);
}
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemmSplitK
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
// Run reference GEMM
if
(
do_verification
)
...
...
profiler/include/profile_grouped_gemm_impl.hpp
View file @
e573a2a0
...
...
@@ -20,7 +20,7 @@
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_grouped_gemm_
instance
{
namespace
instance
{
using
DeviceGroupedGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -36,7 +36,7 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
}
// namespace
device_grouped_gemm_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -171,9 +171,7 @@ void profile_grouped_gemm_impl(int do_verification,
}
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_grouped_gemm_instance
::
DeviceGroupedGemmNoOpPtr
>
gemm_ptrs
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
instance
::
DeviceGroupedGemmNoOpPtr
>
gemm_ptrs
;
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
...
...
@@ -182,28 +180,28 @@ void profile_grouped_gemm_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_grouped_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_grouped_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_grouped_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_grouped_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
gemm_ptrs
);
}
}
...
...
profiler/include/profile_normalization_impl.hpp
View file @
e573a2a0
...
...
@@ -18,7 +18,7 @@
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_normalization_
instance
{
namespace
instance
{
void
add_device_softmax_f16_f16_rank3_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
);
void
add_device_softmax_f16_f16_rank4_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
);
...
...
@@ -26,7 +26,7 @@ void add_device_softmax_f16_f16_rank4_instances(std::vector<DeviceNormalizationP
void
add_device_softmax_f32_f32_rank3_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
);
void
add_device_softmax_f32_f32_rank4_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
);
}
// namespace
device_normalization_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -109,23 +109,23 @@ void profile_normalization_impl(int do_verification,
is_same
<
AccDataType
,
float
>::
value
)
{
if
(
in_length
.
size
()
==
3
)
tensor_operation
::
device
::
device_normalization
_instance
::
add_device_softmax_f16_f16_rank3_instances
(
instances
);
tensor_operation
::
device
::
instance
::
add_device_softmax_f16_f16_rank3
_instance
s
(
instances
);
if
(
in_length
.
size
()
==
4
)
tensor_operation
::
device
::
device_normalization
_instance
::
add_device_softmax_f16_f16_rank4_instances
(
instances
);
tensor_operation
::
device
::
instance
::
add_device_softmax_f16_f16_rank4
_instance
s
(
instances
);
}
else
if
constexpr
(
is_same
<
InDataType
,
float
>::
value
&&
is_same
<
OutDataType
,
float
>::
value
&&
is_same
<
AccDataType
,
float
>::
value
)
{
if
(
in_length
.
size
()
==
3
)
tensor_operation
::
device
::
device_normalization
_instance
::
add_device_softmax_f32_f32_rank3_instances
(
instances
);
tensor_operation
::
device
::
instance
::
add_device_softmax_f32_f32_rank3
_instance
s
(
instances
);
if
(
in_length
.
size
()
==
4
)
tensor_operation
::
device
::
device_normalization
_instance
::
add_device_softmax_f32_f32_rank4_instances
(
instances
);
tensor_operation
::
device
::
instance
::
add_device_softmax_f32_f32_rank4
_instance
s
(
instances
);
}
}
...
...
profiler/include/profile_reduce_impl.hpp
View file @
e573a2a0
...
...
@@ -16,7 +16,7 @@
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_reduce_
instance
{
namespace
instance
{
template
<
int
Rank
,
int
NumReduceDim
,
int
ReduceOpId
,
bool
PropagateNan
,
bool
UseIndex
>
struct
ReduceDescription
...
...
@@ -91,7 +91,7 @@ bool description_match(const DescriptionType& description,
return
(
result
);
};
}
// namespace
device_reduce_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -142,7 +142,7 @@ bool profile_reduce_impl_impl(bool do_verification,
float
beta
)
{
using
namespace
ck
::
tensor_operation
::
device
;
using
namespace
ck
::
tensor_operation
::
device
::
device_reduce_
instance
;
using
namespace
ck
::
tensor_operation
::
device
::
instance
;
using
ck
::
host_common
::
dumpBufferToFile
;
constexpr
bool
op_support_indices
=
...
...
@@ -464,7 +464,7 @@ bool profile_reduce_impl(bool do_verification,
bool
pass
=
true
;
using
tuple_of_description_instances
=
tensor_operation
::
device
::
device_reduce_
instance
::
reduce_description_instances
;
tensor_operation
::
device
::
instance
::
reduce_description_instances
;
const
auto
tuple_object
=
tuple_of_description_instances
{};
...
...
profiler/src/profile_gemm_add_add_fastgelu.cpp
View file @
e573a2a0
...
...
@@ -75,9 +75,7 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
auto
e_type
,
auto
a_layout
,
auto
b_layout
,
auto
d0_layout
,
auto
d1_layout
,
auto
e_layout
)
{
auto
de_layout
)
{
using
ADataType
=
decltype
(
a_type
);
using
BDataType
=
decltype
(
b_type
);
using
AccDataType
=
decltype
(
acc_type
);
...
...
@@ -87,15 +85,13 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
using
ALayout
=
decltype
(
a_layout
);
using
BLayout
=
decltype
(
b_layout
);
using
D0Layout
=
decltype
(
d0_layout
);
using
D1Layout
=
decltype
(
d1_layout
);
using
ELayout
=
decltype
(
e_layout
);
using
DELayout
=
decltype
(
de_layout
);
const
int
DefaultStrideA
=
ck
::
is_same_v
<
ALayout
,
Row
>
?
K
:
M
;
const
int
DefaultStrideB
=
ck
::
is_same_v
<
BLayout
,
Row
>
?
N
:
K
;
const
int
DefaultStrideD0
=
ck
::
is_same_v
<
D
0
Layout
,
Row
>
?
N
:
M
;
const
int
DefaultStrideD1
=
ck
::
is_same_v
<
D
1
Layout
,
Row
>
?
N
:
M
;
const
int
DefaultStrideE
=
ck
::
is_same_v
<
ELayout
,
Row
>
?
N
:
M
;
const
int
DefaultStrideD0
=
ck
::
is_same_v
<
D
E
Layout
,
Row
>
?
N
:
M
;
const
int
DefaultStrideD1
=
ck
::
is_same_v
<
D
E
Layout
,
Row
>
?
N
:
M
;
const
int
DefaultStrideE
=
ck
::
is_same_v
<
D
ELayout
,
Row
>
?
N
:
M
;
bool
pass
=
ck
::
profiler
::
profile_gemm_add_add_fastgelu_impl
<
ADataType
,
BDataType
,
...
...
@@ -105,9 +101,7 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
EDataType
,
ALayout
,
BLayout
,
D0Layout
,
D1Layout
,
ELayout
>
(
DELayout
>
(
do_verification
,
init_method
,
do_log
,
...
...
@@ -126,22 +120,22 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
if
(
data_type
==
MatrixDataType
::
F16_F16_F16_F16_F16
&&
layout
==
MatrixLayout
::
MK_KN_MN_MN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Row
{},
Row
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
MatrixDataType
::
F16_F16_F16_F16_F16
&&
layout
==
MatrixLayout
::
MK_NK_MN_MN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Row
{},
Col
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
MatrixDataType
::
F16_F16_F16_F16_F16
&&
layout
==
MatrixLayout
::
KM_KN_MN_MN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Col
{},
Row
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
MatrixDataType
::
F16_F16_F16_F16_F16
&&
layout
==
MatrixLayout
::
KM_NK_MN_MN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Col
{},
Col
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Col
{},
Col
{},
Row
{});
}
else
{
...
...
script/docker-rocm4.1.sh
deleted
100755 → 0
View file @
6adf3591
WORKSPACE
=
$1
echo
"workspace: "
$WORKSPACE
docker run
\
-it
\
--rm
\
--privileged
\
--group-add
sudo
\
-w
/root/workspace
\
-v
$WORKSPACE
:/root/workspace
\
rocm/tensorflow:rocm4.1-tf1.15-dev
\
/bin/bash
#--network host \
script/docker-rocm4.3.1.sh
deleted
100755 → 0
View file @
6adf3591
WORKSPACE
=
$1
echo
"workspace: "
$WORKSPACE
docker run
\
-it
\
--rm
\
--privileged
\
--group-add
sudo
\
-w
/root/workspace
\
-v
$WORKSPACE
:/root/workspace
\
rocm/tensorflow:rocm4.3.1-tf2.6-dev
\
/bin/bash
#--network host \
test/conv2d_bwd_data/conv2d_bwd_data.cpp
View file @
e573a2a0
...
...
@@ -20,7 +20,7 @@ using INT8 = int8_t;
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv2d_bwd_data_
instance
{
namespace
instance
{
using
DeviceConvBwdDataNoOpPtr
=
DeviceConvBwdDataPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -36,7 +36,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances
(
std
::
vector
<
DeviceConvBwdDataNoOpPtr
>&
);
}
// namespace
device_conv2d_bwd_data_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -220,28 +220,28 @@ int main(int argc, char* argv[])
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
float
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
float
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances
(
conv_ptrs
);
}
else
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
ck
::
half_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
ck
::
half_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
ck
::
half_t
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances
(
conv_ptrs
);
}
else
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
ck
::
bhalf_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
ck
::
bhalf_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
ck
::
bhalf_t
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances
(
conv_ptrs
);
}
else
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
int8_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
int8_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
int8_t
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances
(
conv_ptrs
);
}
...
...
Prev
1
…
8
9
10
11
12
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment