Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ae20247a
Commit
ae20247a
authored
Feb 29, 2024
by
Adam Osewski
Browse files
Merge remote-tracking branch 'origin' into aosewski/ggemm_multi_d2
parents
d1f7a3cf
a776978c
Changes
277
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
422 additions
and
148 deletions
+422
-148
example/62_convnd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp
...onvnd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp
+11
-11
example/62_convnd_activ/run_convnd_activ_example.inc
example/62_convnd_activ/run_convnd_activ_example.inc
+19
-19
example/62_convnd_activ/unary/CMakeLists.txt
example/62_convnd_activ/unary/CMakeLists.txt
+35
-0
example/62_convnd_activ/unary/convnd_fwd_activ_unary_common.hpp
...e/62_convnd_activ/unary/convnd_fwd_activ_unary_common.hpp
+11
-11
example/62_convnd_activ/unary/convnd_fwd_xdl_abs_fp16.cpp
example/62_convnd_activ/unary/convnd_fwd_xdl_abs_fp16.cpp
+11
-0
example/62_convnd_activ/unary/convnd_fwd_xdl_clippedrelu_fp16.cpp
...62_convnd_activ/unary/convnd_fwd_xdl_clippedrelu_fp16.cpp
+11
-0
example/62_convnd_activ/unary/convnd_fwd_xdl_elu_fp16.cpp
example/62_convnd_activ/unary/convnd_fwd_xdl_elu_fp16.cpp
+11
-0
example/62_convnd_activ/unary/convnd_fwd_xdl_leakyrelu_fp16.cpp
...e/62_convnd_activ/unary/convnd_fwd_xdl_leakyrelu_fp16.cpp
+11
-0
example/62_convnd_activ/unary/convnd_fwd_xdl_pow_fp16.cpp
example/62_convnd_activ/unary/convnd_fwd_xdl_pow_fp16.cpp
+11
-0
example/62_convnd_activ/unary/convnd_fwd_xdl_relu_fp16.cpp
example/62_convnd_activ/unary/convnd_fwd_xdl_relu_fp16.cpp
+11
-0
example/62_convnd_activ/unary/convnd_fwd_xdl_sigmoid_fp16.cpp
...ple/62_convnd_activ/unary/convnd_fwd_xdl_sigmoid_fp16.cpp
+11
-0
example/62_convnd_activ/unary/convnd_fwd_xdl_softrelu_fp16.cpp
...le/62_convnd_activ/unary/convnd_fwd_xdl_softrelu_fp16.cpp
+11
-0
example/62_convnd_activ/unary/convnd_fwd_xdl_tanh_fp16.cpp
example/62_convnd_activ/unary/convnd_fwd_xdl_tanh_fp16.cpp
+11
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+42
-27
include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp
...eration/gpu/device/impl/device_elementwise_scale_impl.hpp
+14
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
...tion/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
+10
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
...tion/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+40
-13
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+93
-3
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+33
-51
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+15
-10
No files found.
example/62_conv
_fw
d_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp
→
example/62_conv
n
d_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp
View file @
ae20247a
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <cstdlib>
#include <iostream>
#include <iostream>
...
@@ -100,16 +100,16 @@ template <ck::index_t NDimSpatial,
...
@@ -100,16 +100,16 @@ template <ck::index_t NDimSpatial,
typename
WeiElementOp
,
typename
WeiElementOp
,
typename
OutElementOp
,
typename
OutElementOp
,
typename
DeviceConvNDFwdInstance
>
typename
DeviceConvNDFwdInstance
>
bool
run_grouped_conv
_fwd
(
bool
do_verification
,
bool
run_grouped_conv
(
bool
do_verification
,
int
init_method
,
int
init_method
,
bool
time_kernel
,
bool
time_kernel
,
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
,
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
,
const
HostTensorDescriptor
&
in_g_n_c_wis_desc
,
const
HostTensorDescriptor
&
in_g_n_c_wis_desc
,
const
HostTensorDescriptor
&
wei_g_k_c_xs_desc
,
const
HostTensorDescriptor
&
wei_g_k_c_xs_desc
,
const
HostTensorDescriptor
&
out_g_n_k_wos_desc
,
const
HostTensorDescriptor
&
out_g_n_k_wos_desc
,
const
InElementOp
&
in_element_op
,
const
InElementOp
&
in_element_op
,
const
WeiElementOp
&
wei_element_op
,
const
WeiElementOp
&
wei_element_op
,
const
OutElementOp
&
out_element_op
)
const
OutElementOp
&
out_element_op
)
{
{
constexpr
ck
::
index_t
NumAs
=
2
;
constexpr
ck
::
index_t
NumAs
=
2
;
constexpr
ck
::
index_t
NumBs
=
2
;
constexpr
ck
::
index_t
NumBs
=
2
;
...
...
example/62_conv
_fw
d_activ/run_convnd_
fwd_
activ_example.inc
→
example/62_conv
n
d_activ/run_convnd_activ_example.inc
View file @
ae20247a
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -11,7 +11,7 @@ void print_helper_msg()
...
@@ -11,7 +11,7 @@ void print_helper_msg()
<<
ck
::
utils
::
conv
::
get_conv_param_parser_helper_msg
()
<<
std
::
endl
;
<<
ck
::
utils
::
conv
::
get_conv_param_parser_helper_msg
()
<<
std
::
endl
;
}
}
bool
run_convnd_
fwd_
example
(
int
argc
,
char
*
argv
[])
bool
run_convnd_example
(
int
argc
,
char
*
argv
[])
{
{
print_helper_msg
();
print_helper_msg
();
...
@@ -63,23 +63,23 @@ bool run_convnd_fwd_example(int argc, char* argv[])
...
@@ -63,23 +63,23 @@ bool run_convnd_fwd_example(int argc, char* argv[])
ck
::
utils
::
conv
::
make_output_host_tensor_descriptor_g_n_k_wos_packed
<
OutLayout
>
(
ck
::
utils
::
conv
::
make_output_host_tensor_descriptor_g_n_k_wos_packed
<
OutLayout
>
(
conv_param
);
conv_param
);
return
run_grouped_conv
_fwd
<
NDimSpatial
,
return
run_grouped_conv
<
NDimSpatial
,
InDataType
,
InDataType
,
WeiDataType
,
WeiDataType
,
OutDataType
,
OutDataType
,
InElementOp
,
InElementOp
,
WeiElementOp
,
WeiElementOp
,
OutElementOp
,
OutElementOp
,
DeviceGroupedConvND
Fwd
ActivInstance
>
(
do_verification
,
DeviceGroupedConvNDActivInstance
>
(
do_verification
,
init_method
,
init_method
,
time_kernel
,
time_kernel
,
conv_param
,
conv_param
,
in_g_n_c_wis_desc
,
in_g_n_c_wis_desc
,
wei_g_k_c_xs_desc
,
wei_g_k_c_xs_desc
,
out_g_n_k_wos_desc
,
out_g_n_k_wos_desc
,
in_element_op
,
in_element_op
,
wei_element_op
,
wei_element_op
,
out_element_op
);
out_element_op
);
};
};
if
(
conv_param
.
num_dim_spatial_
==
3
)
if
(
conv_param
.
num_dim_spatial_
==
3
)
...
...
example/62_conv
_fw
d_activ/CMakeLists.txt
→
example/62_conv
n
d_activ/
unary/
CMakeLists.txt
View file @
ae20247a
...
@@ -2,48 +2,34 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
...
@@ -2,48 +2,34 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
example_convnd_
fwd_
activ_xdl
)
add_custom_target
(
example_convnd_activ_
unary_
xdl
)
# Sigmoid
# Sigmoid
add_example_executable
(
example_convnd_fwd_xdl_sigmoid_fp16 convnd_fwd_xdl_sigmoid_fp16.cpp
)
add_example_executable
(
example_convnd_fwd_xdl_sigmoid_fp16 convnd_fwd_xdl_sigmoid_fp16.cpp
)
add_example_dependencies
(
example_convnd_
fwd_
activ_xdl example_convnd_fwd_xdl_sigmoid_fp16
)
add_example_dependencies
(
example_convnd_activ_
unary_
xdl example_convnd_fwd_xdl_sigmoid_fp16
)
# Tanh
# Tanh
add_example_executable
(
example_convnd_fwd_xdl_tanh_fp16 convnd_fwd_xdl_tanh_fp16.cpp
)
add_example_executable
(
example_convnd_fwd_xdl_tanh_fp16 convnd_fwd_xdl_tanh_fp16.cpp
)
add_example_dependencies
(
example_convnd_
fwd_
activ_xdl example_convnd_fwd_xdl_tanh_fp16
)
add_example_dependencies
(
example_convnd_activ_
unary_
xdl example_convnd_fwd_xdl_tanh_fp16
)
# Relu
# Relu
add_example_executable
(
example_convnd_fwd_xdl_relu_fp16 convnd_fwd_xdl_relu_fp16.cpp
)
add_example_executable
(
example_convnd_fwd_xdl_relu_fp16 convnd_fwd_xdl_relu_fp16.cpp
)
add_example_dependencies
(
example_convnd_
fwd_
activ_xdl example_convnd_fwd_xdl_relu_fp16
)
add_example_dependencies
(
example_convnd_activ_
unary_
xdl example_convnd_fwd_xdl_relu_fp16
)
# SoftRelu
# SoftRelu
add_example_executable
(
example_convnd_fwd_xdl_softrelu_fp16 convnd_fwd_xdl_softrelu_fp16.cpp
)
add_example_executable
(
example_convnd_fwd_xdl_softrelu_fp16 convnd_fwd_xdl_softrelu_fp16.cpp
)
add_example_dependencies
(
example_convnd_
fwd_
activ_xdl example_convnd_fwd_xdl_softrelu_fp16
)
add_example_dependencies
(
example_convnd_activ_
unary_
xdl example_convnd_fwd_xdl_softrelu_fp16
)
# Abs
# Abs
add_example_executable
(
example_convnd_fwd_xdl_abs_fp16 convnd_fwd_xdl_abs_fp16.cpp
)
add_example_executable
(
example_convnd_fwd_xdl_abs_fp16 convnd_fwd_xdl_abs_fp16.cpp
)
add_example_dependencies
(
example_convnd_
fwd_
activ_xdl example_convnd_fwd_xdl_abs_fp16
)
add_example_dependencies
(
example_convnd_activ_
unary_
xdl example_convnd_fwd_xdl_abs_fp16
)
# Pow
# Pow
add_example_executable
(
example_convnd_fwd_xdl_pow_fp16 convnd_fwd_xdl_pow_fp16.cpp
)
add_example_executable
(
example_convnd_fwd_xdl_pow_fp16 convnd_fwd_xdl_pow_fp16.cpp
)
add_example_dependencies
(
example_convnd_
fwd_
activ_xdl example_convnd_fwd_xdl_pow_fp16
)
add_example_dependencies
(
example_convnd_activ_
unary_
xdl example_convnd_fwd_xdl_pow_fp16
)
# Clipped Relu
# Clipped Relu
add_example_executable
(
example_convnd_fwd_xdl_clippedrelu_fp16 convnd_fwd_xdl_clippedrelu_fp16.cpp
)
add_example_executable
(
example_convnd_fwd_xdl_clippedrelu_fp16 convnd_fwd_xdl_clippedrelu_fp16.cpp
)
add_example_dependencies
(
example_convnd_
fwd_
activ_xdl example_convnd_fwd_xdl_clippedrelu_fp16
)
add_example_dependencies
(
example_convnd_activ_
unary_
xdl example_convnd_fwd_xdl_clippedrelu_fp16
)
# Leaky Relu
# Leaky Relu
add_example_executable
(
example_convnd_fwd_xdl_leakyrelu_fp16 convnd_fwd_xdl_leakyrelu_fp16.cpp
)
add_example_executable
(
example_convnd_fwd_xdl_leakyrelu_fp16 convnd_fwd_xdl_leakyrelu_fp16.cpp
)
add_example_dependencies
(
example_convnd_
fwd_
activ_xdl example_convnd_fwd_xdl_leakyrelu_fp16
)
add_example_dependencies
(
example_convnd_activ_
unary_
xdl example_convnd_fwd_xdl_leakyrelu_fp16
)
# Elu
# Elu
add_example_executable
(
example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp
)
add_example_executable
(
example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp
)
add_example_dependencies
(
example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_elu_fp16
)
add_example_dependencies
(
example_convnd_activ_unary_xdl example_convnd_fwd_xdl_elu_fp16
)
# ScaleAdd on A and B
add_example_executable
(
example_conv_fwd_xdl_scaleadd_ab_fp16 multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp
)
add_example_dependencies
(
example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_fp16
)
add_example_executable
(
example_conv_fwd_xdl_scaleadd_ab_fp32 multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp
)
add_example_dependencies
(
example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_fp32
)
add_example_executable
(
example_conv_fwd_xdl_scaleadd_ab_bf16 multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp
)
add_example_dependencies
(
example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_bf16
)
add_example_executable
(
example_conv_fwd_xdl_scaleadd_ab_int8 multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp
)
add_example_dependencies
(
example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_int8
)
# ScaleAdd ScaleAdd Relu
add_example_executable
(
example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp
)
add_example_dependencies
(
example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16
)
add_example_executable
(
example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp
)
add_example_dependencies
(
example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16
)
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
example/62_conv
_fw
d_activ/convnd_fwd_activ_common.hpp
→
example/62_conv
n
d_activ/
unary/
convnd_fwd_activ_
unary_
common.hpp
View file @
ae20247a
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -102,16 +102,16 @@ template <ck::index_t NDimSpatial,
...
@@ -102,16 +102,16 @@ template <ck::index_t NDimSpatial,
typename
WeiElementOp
,
typename
WeiElementOp
,
typename
OutElementOp
,
typename
OutElementOp
,
typename
DeviceConvNDFwdInstance
>
typename
DeviceConvNDFwdInstance
>
bool
run_grouped_conv
_fwd
(
bool
do_verification
,
bool
run_grouped_conv
(
bool
do_verification
,
int
init_method
,
int
init_method
,
bool
time_kernel
,
bool
time_kernel
,
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
,
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
,
const
HostTensorDescriptor
&
in_g_n_c_wis_desc
,
const
HostTensorDescriptor
&
in_g_n_c_wis_desc
,
const
HostTensorDescriptor
&
wei_g_k_c_xs_desc
,
const
HostTensorDescriptor
&
wei_g_k_c_xs_desc
,
const
HostTensorDescriptor
&
out_g_n_k_wos_desc
,
const
HostTensorDescriptor
&
out_g_n_k_wos_desc
,
const
InElementOp
&
in_element_op
,
const
InElementOp
&
in_element_op
,
const
WeiElementOp
&
wei_element_op
,
const
WeiElementOp
&
wei_element_op
,
const
OutElementOp
&
out_element_op
)
const
OutElementOp
&
out_element_op
)
{
{
Tensor
<
InDataType
>
in
(
in_g_n_c_wis_desc
);
Tensor
<
InDataType
>
in
(
in_g_n_c_wis_desc
);
Tensor
<
WeiDataType
>
wei
(
wei_g_k_c_xs_desc
);
Tensor
<
WeiDataType
>
wei
(
wei_g_k_c_xs_desc
);
...
...
example/62_convnd_activ/unary/convnd_fwd_xdl_abs_fp16.cpp
0 → 100644
View file @
ae20247a
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnaryAbs
;
using
DeviceGroupedConvNDActivInstance
=
DeviceGroupedConvNDFwdInstance
<
OutElementOp
>
;
#include "../run_convnd_activ_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_example
(
argc
,
argv
);
}
example/62_convnd_activ/unary/convnd_fwd_xdl_clippedrelu_fp16.cpp
0 → 100644
View file @
ae20247a
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
ClippedRelu
;
using
DeviceGroupedConvNDActivInstance
=
DeviceGroupedConvNDFwdInstance
<
OutElementOp
>
;
#include "../run_convnd_activ_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_example
(
argc
,
argv
);
}
example/62_convnd_activ/unary/convnd_fwd_xdl_elu_fp16.cpp
0 → 100644
View file @
ae20247a
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
Elu
;
using
DeviceGroupedConvNDActivInstance
=
DeviceGroupedConvNDFwdInstance
<
OutElementOp
>
;
#include "../run_convnd_activ_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_example
(
argc
,
argv
);
}
example/62_convnd_activ/unary/convnd_fwd_xdl_leakyrelu_fp16.cpp
0 → 100644
View file @
ae20247a
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
LeakyRelu
;
using
DeviceGroupedConvNDActivInstance
=
DeviceGroupedConvNDFwdInstance
<
OutElementOp
>
;
#include "../run_convnd_activ_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_example
(
argc
,
argv
);
}
example/62_convnd_activ/unary/convnd_fwd_xdl_pow_fp16.cpp
0 → 100644
View file @
ae20247a
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
Power
;
using
DeviceGroupedConvNDActivInstance
=
DeviceGroupedConvNDFwdInstance
<
OutElementOp
>
;
#include "../run_convnd_activ_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_example
(
argc
,
argv
);
}
example/62_convnd_activ/unary/convnd_fwd_xdl_relu_fp16.cpp
0 → 100644
View file @
ae20247a
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
Relu
;
using
DeviceGroupedConvNDActivInstance
=
DeviceGroupedConvNDFwdInstance
<
OutElementOp
>
;
#include "../run_convnd_activ_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_example
(
argc
,
argv
);
}
example/62_convnd_activ/unary/convnd_fwd_xdl_sigmoid_fp16.cpp
0 → 100644
View file @
ae20247a
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
Sigmoid
;
using
DeviceGroupedConvNDActivInstance
=
DeviceGroupedConvNDFwdInstance
<
OutElementOp
>
;
#include "../run_convnd_activ_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_example
(
argc
,
argv
);
}
example/62_convnd_activ/unary/convnd_fwd_xdl_softrelu_fp16.cpp
0 → 100644
View file @
ae20247a
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
SoftRelu
;
using
DeviceGroupedConvNDActivInstance
=
DeviceGroupedConvNDFwdInstance
<
OutElementOp
>
;
#include "../run_convnd_activ_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_example
(
argc
,
argv
);
}
example/62_convnd_activ/unary/convnd_fwd_xdl_tanh_fp16.cpp
0 → 100644
View file @
ae20247a
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
TanH
;
using
DeviceGroupedConvNDActivInstance
=
DeviceGroupedConvNDFwdInstance
<
OutElementOp
>
;
#include "../run_convnd_activ_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_example
(
argc
,
argv
);
}
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
ae20247a
...
@@ -37,7 +37,9 @@ template <index_t BlockSize,
...
@@ -37,7 +37,9 @@ template <index_t BlockSize,
index_t
NPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
NRepeat
,
index_t
KPack
>
index_t
KPack
,
typename
ComputeTypeA
=
FloatA
,
typename
ComputeTypeB
=
FloatB
>
struct
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
struct
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -59,7 +61,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -59,7 +61,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static
constexpr
index_t
A_K1
=
AK0MK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
A_K1
=
AK0MK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatA
,
MPerXDL
,
NPerXDL
,
KPack
,
FloatB
>
{};
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
ComputeTypeA
,
MPerXDL
,
NPerXDL
,
KPack
,
ComputeTypeB
>
{};
static
constexpr
index_t
KPerThread
=
KPerBlock
/
xdlops_gemm
.
K0PerXdlops
;
static
constexpr
index_t
KPerThread
=
KPerBlock
/
xdlops_gemm
.
K0PerXdlops
;
...
@@ -295,9 +298,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -295,9 +298,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const
BBlockBuffer
&
b_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
CThreadBuffer
&
c_thread_buf
)
const
{
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
A
>
(
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeType
A
>
(
a_thread_desc_
.
GetElementSpaceSize
());
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
B
>
(
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeType
B
>
(
b_thread_desc_
.
GetElementSpaceSize
());
b_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
...
@@ -319,20 +322,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -319,20 +322,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
b_thread_buf
);
b_thread_buf
);
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
vector_type
<
Float
A
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeType
A
,
KPack
>
a_thread_vec
;
vector_type
<
Float
B
,
KPack
>
b_thread_vec
;
vector_type
<
ComputeType
B
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
Float
A
>()(
i
)
=
a_thread_buf
a_thread_vec
.
template
AsType
<
ComputeType
A
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
Float
B
>()(
i
)
=
b_thread_buf
b_thread_vec
.
template
AsType
<
ComputeType
B
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
});
});
using
mfma_input_type_a
=
using
mfma_input_type_a
=
typename
vector_type
<
Float
A
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeType
A
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
using
mfma_input_type_b
=
using
mfma_input_type_b
=
typename
vector_type
<
Float
B
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeType
B
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
...
@@ -360,7 +363,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -360,7 +363,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
xdlops_gemm
.
GetRegSizePerXdlops
()));
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
xdlops_gemm
.
GetRegSizePerXdlops
()));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
Float
A
,
ComputeType
A
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
...
@@ -370,7 +373,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -370,7 +373,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
A_K1
>
;
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
Float
B
,
ComputeType
B
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
...
@@ -398,6 +401,8 @@ template <index_t BlockSize,
...
@@ -398,6 +401,8 @@ template <index_t BlockSize,
index_t
MRepeat
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
NRepeat
,
index_t
KPack
,
index_t
KPack
,
typename
ComputeTypeA
=
FloatA
,
typename
ComputeTypeB
=
FloatB
,
index_t
NumMacClusters
=
CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
>
index_t
NumMacClusters
=
CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
>
struct
BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
struct
BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
:
public
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
:
public
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
...
@@ -410,7 +415,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -410,7 +415,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
NPerXDL
,
NPerXDL
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
KPack
>
KPack
,
ComputeTypeA
,
ComputeTypeB
>
{
{
using
Base
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
using
Base
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatA
,
FloatA
,
...
@@ -422,7 +429,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -422,7 +429,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
NPerXDL
,
NPerXDL
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
KPack
>
;
KPack
,
ComputeTypeA
,
ComputeTypeB
>
;
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
...
@@ -446,9 +455,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -446,9 +455,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const
BBlockBuffer
&
b_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
CThreadBuffer
&
c_thread_buf
)
const
{
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
A
>
(
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeType
A
>
(
a_thread_desc_
.
GetElementSpaceSize
());
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
B
>
(
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeType
B
>
(
b_thread_desc_
.
GetElementSpaceSize
());
b_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
KPerThread
,
KPerInnerLoop
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPerThread
,
KPerInnerLoop
>
{}([
&
](
auto
k
)
{
...
@@ -485,22 +494,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -485,22 +494,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
Float
A
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeType
A
,
KPack
>
a_thread_vec
;
vector_type
<
Float
B
,
KPack
>
b_thread_vec
;
vector_type
<
ComputeType
B
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
Float
A
>()(
i
)
=
a_thread_vec
.
template
AsType
<
ComputeType
A
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
0
,
0
,
k_
+
i
))
>
{}];
make_tuple
(
m0
,
0
,
0
,
k_
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
Float
B
>()(
i
)
=
b_thread_vec
.
template
AsType
<
ComputeType
B
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
0
,
0
,
k_
+
i
))
>
{}];
make_tuple
(
n0
,
0
,
0
,
k_
+
i
))
>
{}];
});
});
using
mfma_input_type_a
=
using
mfma_input_type_a
=
typename
vector_type
<
Float
A
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeType
A
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
using
mfma_input_type_b
=
using
mfma_input_type_b
=
typename
vector_type
<
Float
B
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeType
B
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
...
@@ -550,7 +559,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -550,7 +559,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
I1
,
Number
<
KPerInnerLoop
>
{}));
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
I1
,
Number
<
KPerInnerLoop
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
Float
A
,
ComputeType
A
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
...
@@ -560,7 +569,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -560,7 +569,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
A_K1
>
;
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
Float
B
,
ComputeType
B
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
...
@@ -586,7 +595,9 @@ template <index_t BlockSize,
...
@@ -586,7 +595,9 @@ template <index_t BlockSize,
index_t
MRepeat
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
NRepeat
,
index_t
KPack
,
index_t
KPack
,
LoopScheduler
LoopSched
>
LoopScheduler
LoopSched
,
typename
ComputeTypeA
=
FloatA
,
typename
ComputeTypeB
=
FloatB
>
constexpr
auto
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
()
constexpr
auto
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
()
{
{
if
constexpr
(
LoopSched
==
LoopScheduler
::
Default
)
if
constexpr
(
LoopSched
==
LoopScheduler
::
Default
)
...
@@ -601,7 +612,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
...
@@ -601,7 +612,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
NPerXDL
,
NPerXDL
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
KPack
>
{};
KPack
,
ComputeTypeA
,
ComputeTypeB
>
{};
}
}
else
if
constexpr
(
LoopSched
==
LoopScheduler
::
Interwave
)
else
if
constexpr
(
LoopSched
==
LoopScheduler
::
Interwave
)
{
{
...
@@ -615,7 +628,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
...
@@ -615,7 +628,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
NPerXDL
,
NPerXDL
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
KPack
>
{};
KPack
,
ComputeTypeA
,
ComputeTypeB
>
{};
}
}
};
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp
View file @
ae20247a
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -322,6 +322,19 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
...
@@ -322,6 +322,19 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
{
{
return
std
::
make_unique
<
Invoker
>
();
return
std
::
make_unique
<
Invoker
>
();
};
};
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceElementwiseNormalizationImpl<"
;
str
<<
NumDim
<<
", "
;
str
<<
MPerThread
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
// namespace device
};
// namespace device
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
ae20247a
...
@@ -60,7 +60,9 @@ template <typename ADataType,
...
@@ -60,7 +60,9 @@ template <typename ADataType,
index_t
CBlockTransferScalarPerVector_NWaveNPerXDL
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXDL
,
typename
ComputeType
=
CDataType
,
typename
ComputeType
=
CDataType
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
typename
LDSTypeA
=
ComputeType
,
typename
LDSTypeB
=
ComputeType
>
struct
DeviceGemmXdlSplitKCShuffle
:
public
DeviceGemmSplitK
<
ALayout
,
struct
DeviceGemmXdlSplitKCShuffle
:
public
DeviceGemmSplitK
<
ALayout
,
BLayout
,
BLayout
,
...
@@ -81,6 +83,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -81,6 +83,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
// TODO: should be exposed as Tparams.
// TODO: should be exposed as Tparams.
static
constexpr
index_t
NumGemmKPrefetchStage
=
1
;
static
constexpr
index_t
NumGemmKPrefetchStage
=
1
;
using
ComputeTypeA
=
ComputeType
;
using
ComputeTypeB
=
ComputeType
;
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
BlockSize
,
BlockSize
,
ADataType
,
ADataType
,
...
@@ -125,7 +130,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -125,7 +130,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
LoopSched
,
LoopSched
,
PipelineVer
,
PipelineVer
,
ComputeType
>
;
ComputeTypeA
,
ComputeTypeB
,
LDSTypeA
,
LDSTypeB
>
;
struct
Argument
:
public
GridwiseGemm
::
Argument
struct
Argument
:
public
GridwiseGemm
::
Argument
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
View file @
ae20247a
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -650,32 +650,52 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -650,32 +650,52 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
constexpr
auto
AtomicAdd
=
InMemoryDataOperationEnum
::
AtomicAdd
;
constexpr
auto
AtomicAdd
=
InMemoryDataOperationEnum
::
AtomicAdd
;
constexpr
auto
Set
=
InMemoryDataOperationEnum
::
Set
;
constexpr
auto
Set
=
InMemoryDataOperationEnum
::
Set
;
if
(
arg
.
k_batch_
>
1
)
// For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced
// in IsSupportedArgument function
if
constexpr
(
std
::
is_same
<
ADataType
,
ck
::
bhalf_t
>::
value
)
{
{
if
(
has_main_k_block_loop
)
if
(
has_main_k_block_loop
)
{
{
ave_time
=
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
InMemoryDataOperationEnum
,
Set
>
{});
integral_constant
<
InMemoryDataOperationEnum
,
AtomicAdd
>
{});
}
}
else
else
{
{
ave_time
=
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
InMemoryDataOperationEnum
,
Set
>
{});
integral_constant
<
InMemoryDataOperationEnum
,
AtomicAdd
>
{});
}
}
}
}
else
else
{
{
if
(
has_main_k_block_loop
)
if
(
arg
.
k_batch_
>
1
)
{
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
if
(
has_main_k_block_loop
)
integral_constant
<
InMemoryDataOperationEnum
,
Set
>
{});
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
InMemoryDataOperationEnum
,
AtomicAdd
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
InMemoryDataOperationEnum
,
AtomicAdd
>
{});
}
}
}
else
else
{
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
if
(
has_main_k_block_loop
)
integral_constant
<
InMemoryDataOperationEnum
,
Set
>
{});
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
InMemoryDataOperationEnum
,
Set
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
InMemoryDataOperationEnum
,
Set
>
{});
}
}
}
}
}
...
@@ -718,6 +738,13 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -718,6 +738,13 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
}
}
}
}
// For bf16 datatype only kbatch = 1 is supported since there is no AtomicAdd
// instruction that supports bf16 and we cannot use splitk because of that
if
constexpr
(
std
::
is_same
<
ADataType
,
ck
::
bhalf_t
>::
value
)
{
supported
=
supported
&
(
arg
.
k_batch_
==
1
);
}
return
supported
;
return
supported
;
}
}
...
...
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
ae20247a
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -75,6 +75,15 @@ struct Add
...
@@ -75,6 +75,15 @@ struct Add
y
=
ck
::
type_convert
<
bhalf_t
>
(
y_tmp
);
y
=
ck
::
type_convert
<
bhalf_t
>
(
y_tmp
);
}
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
bhalf_t
>
(
bhalf_t
&
y
,
const
float
&
x0
,
const
bhalf_t
&
x1
)
const
{
const
float
x2_tmp
=
ck
::
type_convert
<
float
>
(
x1
);
const
float
y_tmp
=
x0
+
x2_tmp
;
y
=
ck
::
type_convert
<
bhalf_t
>
(
y_tmp
);
}
template
<
>
template
<
>
__host__
__device__
constexpr
void
__host__
__device__
constexpr
void
operator
()
<
int8_t
>
(
int8_t
&
y
,
const
int8_t
&
x0
,
const
int8_t
&
x1
)
const
operator
()
<
int8_t
>
(
int8_t
&
y
,
const
int8_t
&
x0
,
const
int8_t
&
x1
)
const
...
@@ -156,7 +165,7 @@ struct Subtract
...
@@ -156,7 +165,7 @@ struct Subtract
struct
Bilinear
struct
Bilinear
{
{
Bilinear
(
float
alpha
,
float
beta
)
:
alpha_
(
alpha
),
beta_
(
beta
){};
Bilinear
(
float
alpha
=
1.
f
,
float
beta
=
1.
f
)
:
alpha_
(
alpha
),
beta_
(
beta
){};
template
<
typename
Y
,
typename
X0
,
typename
X1
>
template
<
typename
Y
,
typename
X0
,
typename
X1
>
__host__
__device__
constexpr
void
operator
()(
Y
&
,
const
X0
&
,
const
X1
&
)
const
;
__host__
__device__
constexpr
void
operator
()(
Y
&
,
const
X0
&
,
const
X1
&
)
const
;
...
@@ -175,6 +184,14 @@ struct Bilinear
...
@@ -175,6 +184,14 @@ struct Bilinear
y
=
alpha_
*
x0
+
beta_
*
x1
;
y
=
alpha_
*
x0
+
beta_
*
x1
;
};
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
int8_t
,
int8_t
,
int8_t
>
(
int8_t
&
y
,
const
int8_t
&
x0
,
const
int8_t
&
x1
)
const
{
y
=
type_convert
<
int8_t
>
(
alpha_
*
type_convert
<
float
>
(
x0
)
+
beta_
*
type_convert
<
float
>
(
x1
));
};
template
<
>
template
<
>
__host__
__device__
constexpr
void
__host__
__device__
constexpr
void
operator
()
<
half_t
,
half_t
,
half_t
>
(
half_t
&
y
,
const
half_t
&
x0
,
const
half_t
&
x1
)
const
operator
()
<
half_t
,
half_t
,
half_t
>
(
half_t
&
y
,
const
half_t
&
x0
,
const
half_t
&
x1
)
const
...
@@ -212,7 +229,8 @@ struct Bilinear
...
@@ -212,7 +229,8 @@ struct Bilinear
__host__
__device__
constexpr
void
operator
()
<
std
::
int8_t
,
std
::
int32_t
,
std
::
int8_t
>
(
__host__
__device__
constexpr
void
operator
()
<
std
::
int8_t
,
std
::
int32_t
,
std
::
int8_t
>
(
std
::
int8_t
&
y
,
const
std
::
int32_t
&
x0
,
const
std
::
int8_t
&
x1
)
const
std
::
int8_t
&
y
,
const
std
::
int32_t
&
x0
,
const
std
::
int8_t
&
x1
)
const
{
{
y
=
type_convert
<
std
::
int8_t
>
(
x0
+
ck
::
type_convert
<
std
::
int32_t
>
(
x1
));
y
=
type_convert
<
int8_t
>
(
alpha_
*
type_convert
<
float
>
(
x0
)
+
beta_
*
type_convert
<
float
>
(
x1
));
};
};
float
alpha_
;
float
alpha_
;
...
@@ -264,6 +282,14 @@ struct AddRelu
...
@@ -264,6 +282,14 @@ struct AddRelu
y
=
a
>
0.0
f
?
a
:
0.0
f
;
y
=
a
>
0.0
f
?
a
:
0.0
f
;
};
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
bhalf_t
,
float
,
bhalf_t
>
(
bhalf_t
&
y
,
const
float
&
x0
,
const
bhalf_t
&
x1
)
const
{
const
float
a
=
x0
+
type_convert
<
float
>
(
x1
);
y
=
a
>
type_convert
<
bhalf_t
>
(
0.0
f
)
?
a
:
type_convert
<
bhalf_t
>
(
0.0
f
);
};
template
<
>
template
<
>
__host__
__device__
constexpr
void
__host__
__device__
constexpr
void
operator
()
<
int
,
int
,
int8_t
>
(
int
&
y
,
const
int
&
x0
,
const
int8_t
&
x1
)
const
operator
()
<
int
,
int
,
int8_t
>
(
int
&
y
,
const
int
&
x0
,
const
int8_t
&
x1
)
const
...
@@ -354,6 +380,70 @@ struct AddFastGelu
...
@@ -354,6 +380,70 @@ struct AddFastGelu
e
=
type_convert
<
half_t
>
(
x1_f
);
e
=
type_convert
<
half_t
>
(
x1_f
);
}
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
bhalf_t
,
float
,
bhalf_t
>
(
bhalf_t
&
e
,
const
float
&
c
,
const
bhalf_t
&
d
)
const
{
const
float
x0_f
=
c
+
type_convert
<
float
>
(
d
);
float
x1_f
=
0
;
FastGelu
{}.
template
operator
()
<
float
,
float
>(
x1_f
,
x0_f
);
e
=
type_convert
<
bhalf_t
>
(
x1_f
);
}
};
// E = Silu(C + D)
struct
AddSilu
{
template
<
typename
E
,
typename
C
,
typename
D
>
__host__
__device__
constexpr
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D
&
d
)
const
;
template
<
>
__host__
__device__
constexpr
void
operator
()
<
float
,
float
,
float
>
(
float
&
e
,
const
float
&
c
,
const
float
&
d
)
const
{
const
float
x
=
c
+
d
;
Silu
{}.
template
operator
()
<
float
>(
e
,
x
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
half_t
,
half_t
,
half_t
>
(
half_t
&
e
,
const
half_t
&
c
,
const
half_t
&
d
)
const
{
const
half_t
x
=
c
+
d
;
Silu
{}.
template
operator
()
<
half_t
>(
e
,
x
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
half_t
,
float
,
half_t
>
(
half_t
&
e
,
const
float
&
c
,
const
half_t
&
d
)
const
{
const
float
x0_f
=
c
+
d
;
float
x1_f
=
0
;
Silu
{}.
template
operator
()
<
float
>(
x1_f
,
x0_f
);
e
=
type_convert
<
half_t
>
(
x1_f
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
bhalf_t
,
float
,
bhalf_t
>
(
bhalf_t
&
e
,
const
float
&
c
,
const
bhalf_t
&
d
)
const
{
const
float
x0_f
=
c
+
type_convert
<
float
>
(
d
);
float
x1_f
=
0
;
Silu
{}.
template
operator
()
<
float
>(
x1_f
,
x0_f
);
e
=
type_convert
<
bhalf_t
>
(
x1_f
);
}
};
};
}
// namespace element_wise
}
// namespace element_wise
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
ae20247a
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -21,50 +21,11 @@ struct PassThroughPack2
...
@@ -21,50 +21,11 @@ struct PassThroughPack2
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
__host__
__device__
constexpr
void
operator
()(
ck
::
f8x2_t
&
y
,
const
ck
::
half2_t
&
x
)
const
{
// fake conversion
uint16_t
t
=
ck
::
bit_cast
<
uint32_t
>
(
x
);
y
=
ck
::
bit_cast
<
ck
::
f8x2_t
>
(
t
);
}
__host__
__device__
constexpr
void
operator
()(
ck
::
half2_t
&
y
,
const
ck
::
f8x2_t
&
x
)
const
__host__
__device__
constexpr
void
operator
()(
ck
::
half2_t
&
y
,
const
ck
::
f8x2_t
&
x
)
const
{
{
auto
t
=
type_convert
<
float2_t
>
(
x
);
auto
t
=
type_convert
<
float2_t
>
(
x
);
y
=
type_convert
<
half2_t
>
(
t
);
y
=
type_convert
<
half2_t
>
(
t
);
}
}
__host__
__device__
constexpr
void
operator
()(
ck
::
half2_t
&
y
,
const
ck
::
half2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
f8x2_t
&
y
,
const
ck
::
f8x2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
float2_t
&
y
,
const
ck
::
float2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
int8x2_t
&
y
,
const
ck
::
int8x2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
bhalf2_t
&
y
,
const
ck
::
bhalf2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
double2_t
&
y
,
const
ck
::
double2_t
&
x
)
const
{
y
=
x
;
}
constexpr
const
static
bool
is_pack2_invocable
=
true
;
};
};
struct
PassThrough
struct
PassThrough
...
@@ -156,6 +117,12 @@ struct PassThrough
...
@@ -156,6 +117,12 @@ struct PassThrough
y
=
type_convert
<
half_t
>
(
x
);
y
=
type_convert
<
half_t
>
(
x
);
}
}
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
int8_t
>
(
bhalf_t
&
y
,
const
int8_t
&
x
)
const
{
y
=
type_convert
<
bhalf_t
>
(
x
);
}
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
int8_t
,
int32_t
>
(
int8_t
&
y
,
const
int32_t
&
x
)
const
__host__
__device__
void
operator
()
<
int8_t
,
int32_t
>
(
int8_t
&
y
,
const
int32_t
&
x
)
const
{
{
...
@@ -452,27 +419,29 @@ struct FastGelu
...
@@ -452,27 +419,29 @@ struct FastGelu
template
<
>
template
<
>
__host__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
__host__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
{
const
float
u
=
2.
f
*
x
*
(
0.035677
f
*
x
*
x
+
0.797885
f
);
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const
float
emu
=
exp
(
-
u
);
const
float
c1
=
-
2.0
*
0.035677
f
;
const
float
cdf
=
0.5
f
+
0.5
f
*
(
2.
f
/
(
1.
f
+
emu
)
-
1.
f
);
const
float
c2
=
-
2.0
*
0.797885
f
;
const
float
u
=
x
*
(
c1
*
x
*
x
+
c2
);
y
=
x
*
cdf
;
const
float
emu
=
exp
(
u
);
y
=
x
/
(
1.
f
+
emu
);
}
}
// device code, use lower precision "__expf" and "rcp"
// device code, use lower precision "__expf" and "rcp"
template
<
>
template
<
>
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
{
const
float
u
=
2.
f
*
x
*
(
0.035677
f
*
x
*
x
+
0.797885
f
);
// const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const
float
emu
=
__expf
(
-
u
);
const
float
c1
=
-
2.0
*
0.035677
f
;
const
float
c2
=
-
2.0
*
0.797885
f
;
const
float
u
=
x
*
(
c1
*
x
*
x
+
c2
);
const
float
emu
=
__expf
(
u
);
#if !CK_WORKAROUND_SWDEV_383542
#if !CK_WORKAROUND_SWDEV_383542
const
float
cdf
=
0.5
f
+
0.5
f
*
(
2.
f
*
__frcp_rn
(
1.
f
+
emu
)
-
1.
f
)
;
y
=
x
*
__frcp_rn
(
1.
f
+
emu
);
#else
#else
const
float
cdf
=
0.5
f
+
0.5
f
*
(
2.
f
*
__ocml_native_recip_f32
(
1.
f
+
emu
)
-
1.
f
)
;
y
=
x
*
__ocml_native_recip_f32
(
1.
f
+
emu
);
#endif
#endif
y
=
x
*
cdf
;
}
}
template
<
>
template
<
>
...
@@ -551,6 +520,19 @@ struct Sigmoid
...
@@ -551,6 +520,19 @@ struct Sigmoid
};
};
};
};
struct
Silu
{
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same_v
<
T
,
float
>
||
is_same_v
<
T
,
double
>
||
is_same_v
<
T
,
ck
::
half_t
>
||
is_same_v
<
T
,
int8_t
>
||
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
x
*
(
one
/
(
one
+
ck
::
math
::
exp
(
-
x
)));
};
};
struct
TanH
struct
TanH
{
{
template
<
typename
T
>
template
<
typename
T
>
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
ae20247a
...
@@ -9,7 +9,6 @@
...
@@ -9,7 +9,6 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...
@@ -96,7 +95,10 @@ template <index_t BlockSize,
...
@@ -96,7 +95,10 @@ template <index_t BlockSize,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
,
typename
ComputeType
=
FloatC
>
typename
ComputeTypeA
=
FloatC
,
typename
ComputeTypeB
=
ComputeTypeA
,
typename
LDSTypeA
=
ComputeTypeA
,
typename
LDSTypeB
=
ComputeTypeB
>
struct
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
struct
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -430,7 +432,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -430,7 +432,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr
auto
c_block_size
=
constexpr
auto
c_block_size
=
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
().
GetElementSpaceSize
();
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
().
GetElementSpaceSize
();
return
math
::
max
((
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
ComputeType
),
return
math
::
max
(
a_block_space_size
*
sizeof
(
LDSTypeA
)
+
b_block_space_size
*
sizeof
(
LDSTypeB
),
c_block_size
*
sizeof
(
FloatC
));
c_block_size
*
sizeof
(
FloatC
));
}
}
...
@@ -785,7 +788,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -785,7 +788,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
FloatA
,
FloatA
,
Compute
Type
,
LDS
Type
A
,
decltype
(
a_b_k0_m_k1_grid_desc
),
decltype
(
a_b_k0_m_k1_grid_desc
),
decltype
(
a_b_k0_m_k1_block_desc
),
decltype
(
a_b_k0_m_k1_block_desc
),
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
...
@@ -815,7 +818,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -815,7 +818,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
FloatB
,
FloatB
,
Compute
Type
,
LDS
Type
B
,
decltype
(
b_b_k0_n_k1_grid_desc
),
decltype
(
b_b_k0_n_k1_grid_desc
),
decltype
(
b_b_k0_n_k1_block_desc
),
decltype
(
b_b_k0_n_k1_block_desc
),
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
...
@@ -845,8 +848,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -845,8 +848,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
BlockSize
,
ComputeType
,
// Compute
Type
A
LDS
TypeA
,
ComputeType
,
// Compute
Type
B
LDS
TypeB
,
FloatAcc
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
...
@@ -855,7 +858,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -855,7 +858,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
K1
,
K1
,
LoopSched
>
();
LoopSched
,
ComputeTypeA
,
ComputeTypeB
>
();
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
...
@@ -863,8 +868,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -863,8 +868,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr
auto
a_block_space_size
=
constexpr
auto
a_block_space_size
=
math
::
integer_least_multiple
(
a_k0_m_k1_block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
a_k0_m_k1_block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
ComputeType
*
p_a_block
=
static_cast
<
Compute
Type
*>
(
p_shared_block
);
auto
p_a_block
=
reinterpret_cast
<
LDS
Type
A
*>
(
p_shared_block
);
ComputeType
*
p_b_block
=
static_cast
<
Compute
Type
*>
(
p_
shared
_block
)
+
a_block_space_size
;
auto
p_b_block
=
reinterpret_cast
<
LDS
Type
B
*>
(
p_
a
_block
+
a_block_space_size
)
;
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
0
,
K0PerBlock
,
0
,
0
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
0
,
K0PerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
0
,
K0PerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
0
,
K0PerBlock
,
0
,
0
);
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment