Commit 11279540 authored by Astha Rai's avatar Astha Rai
Browse files

Merge branch 'transpose_5d' of github.com:ROCmSoftwarePlatform/composable_kernel into transpose_5d

parents 14daa201 33e78b9a
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_convnd_fwd_reduce_xdl)
add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
endif()
add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
endif()
add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
endif()
add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
endif()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)
set(target 1)
endif()
endforeach()
\ No newline at end of file
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_convnd_fwd_reduce_xdl)
add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)
set(target 1)
endif()
endforeach()
......@@ -2,7 +2,7 @@
## Run ```example_reduce_blockwise```
```bash
# -D <xxx> : input 3d/4d/5d tensor lengths
# -D <xxx> : input 3D/4D/5D tensor lengths
# -R <xxx> : reduce dimension ids
# -v <x> : verification (0=no, 1=yes)
#arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64, 7: int4)
......@@ -22,7 +22,7 @@ Perf: 0.238063 ms, 264.285 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSr
## Run ```example_reduce_multiblock_atomic_add```
```bash
# -D <xxx> : input 3d/4d/5d tensor lengths
# -D <xxx> : input 3D/4D/5D tensor lengths
# -R <xxx> : reduce dimension ids
# -v <x> : verification (0=no, 1=yes)
#arg1: data type (0: fp32, 1: fp64)
......
add_custom_target(example_grouped_gemm_xdl)
add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32)
endif()
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32)
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16)
endif()
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16)
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16)
endif()
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16)
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16)
endif()
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16)
endif()
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16)
endif()
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16)
add_example_executable(example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
endif()
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
endif()
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
endif()
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
endif()
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
endif()
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_gemm_reduce_xdl)
add_custom_target(example_gemm_reduce_xdl_max)
add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
endif()
add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
endif()
add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
endif()
add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
endif()
add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
endif()
add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
endif()
add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
endif()
add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
endif()
add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
endif()
add_dependencies(example_gemm_reduce_xdl
example_gemm_reduce_xdl_mean_meansquare
example_gemm_reduce_xdl_max
example_gemm_add_add_mean_meansquare_xdl)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
endif()
endif()
set(target 1)
endif()
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_gemm_reduce_xdl)
add_custom_target(example_gemm_reduce_xdl_max)
add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
add_example_dependencies(example_gemm_reduce_xdl
example_gemm_reduce_xdl_mean_meansquare
example_gemm_reduce_xdl_max
example_gemm_add_add_mean_meansquare_xdl)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
endif()
set(target 1)
endif()
endforeach()
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_weight)
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
endif()
add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
endif()
if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_weight)
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
set(target 1)
endif()
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_weight)
add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
set(target 1)
endif()
endif()
set(target 1)
endif()
endforeach()
add_custom_target(example_grouped_conv_bwd_weight_dl)
add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
endif()
add_example_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
......@@ -46,25 +46,21 @@ struct CommonLayoutSetting
using OutputLayout = OutputLay;
};
template <ck::index_t NDimSpatial>
struct CommonLayoutSettingSelector;
namespace ctl = ck::tensor_layout::convolution;
template <>
struct CommonLayoutSettingSelector<1> final : CommonLayoutSetting<ctl::GNWC, ctl::GKXC, ctl::GNWK>
{
};
template <>
struct CommonLayoutSettingSelector<2> final
: CommonLayoutSetting<ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>
{
};
template <>
struct CommonLayoutSettingSelector<3> final
: CommonLayoutSetting<ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>
template <ck::index_t NDimSpatial>
struct CommonLayoutSettingSelector
: CommonLayoutSetting<ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWC,
ck::tensor_layout::convolution::GNHWC,
ck::tensor_layout::convolution::GNDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GKXC,
ck::tensor_layout::convolution::GKYXC,
ck::tensor_layout::convolution::GKZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWK,
ck::tensor_layout::convolution::GNHWK,
ck::tensor_layout::convolution::GNDHWK>>>
{
};
......@@ -84,10 +80,10 @@ struct ExecutionConfig final
bool time_kernel = false;
};
#define DefaultConvParam \
ck::utils::conv::ConvParam \
{ \
2, 4, 1, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, { 1, 1 } \
#define DefaultConvParam \
ck::utils::conv::ConvParam \
{ \
3, 4, 1, 128, 256, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, { 1, 1, 1 } \
}
inline void print_help_msg()
......
......@@ -76,4 +76,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
int main(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return 1;
}
switch(conv_param.num_dim_spatial_)
{
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
using InDataType = F16;
using WeiDataType = F16;
using OutDataType = F16;
using AccDataType = F32;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
NDimSpatial,
ck::tensor_layout::convolution::GNDHWC,
ck::tensor_layout::convolution::GKZYXC,
ck::tensor_layout::convolution::GNDHWK,
InDataType, // InDataType
WeiDataType, // WeiDataType
OutDataType, // OutDataType
AccDataType, // AccDataType
InElementOp, // InElementwiseOperation
WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
16, // MPerWMMA
16, // NPerWMMA
4, // MRepeat
2, // NRepeat
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<0, 2, 1>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
4,
2,
S<1, 32, 1, 8>,
1>;
template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
#include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return 1;
}
switch(conv_param.num_dim_spatial_)
{
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
}
......@@ -78,4 +78,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
int main(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return 1;
}
switch(conv_param.num_dim_spatial_)
{
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
}
......@@ -77,4 +77,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
int main(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return 1;
}
switch(conv_param.num_dim_spatial_)
{
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
}
......@@ -83,4 +83,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
int main(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return 1;
}
switch(conv_param.num_dim_spatial_)
{
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
}
......@@ -5,7 +5,7 @@ template <ck::index_t NDimSpatial>
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
const ck::utils::conv::ConvParam& conv_param)
{
// Dl op doesn't support split_k > 1
// Dl and WMMA ops don't support split_k > 1
constexpr ck::index_t split_k = 1;
const auto in_g_n_c_wis_desc =
......@@ -143,23 +143,3 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
return true;
}
bool run_grouped_conv_bwd_weight_example(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return false;
}
switch(conv_param.num_dim_spatial_)
{
case 1: return run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return run_grouped_conv_bwd_weight<3>(config, conv_param);
}
return false;
}
......@@ -114,12 +114,15 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
BetaDataType,
HDataType,
AccDataType,
AccDataType,
HElementOp,
2,
1>;
Tensor<EMeanVarDataType> e_m_n(HostTensorDescriptor{M, N});
Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});
Tensor<AccDataType> save_mean({M});
Tensor<AccDataType> save_inv_std({M});
auto ref_gemm = ReferenceGemm{};
auto ref_gemm_invoker = ref_gemm.MakeInvoker();
......@@ -145,7 +148,7 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
auto ref_layernorm_invoker = ref_layernorm.MakeInvoker();
auto ref_layernorm_argument = ref_layernorm.MakeArgument(
e_m_n, gamma_n, beta_n, h_m_n, h_element_op, {M, N}, {1}, epsilon);
e_m_n, gamma_n, beta_n, h_m_n, save_mean, save_inv_std, h_element_op, {M, N}, {1}, epsilon);
ref_layernorm_invoker.Run(ref_layernorm_argument);
}
......
add_custom_target(example_cgemm_xdl)
add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16)
endif()
add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16)
add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16)
endif()
add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16)
add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32)
endif()
add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32)
add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_int8)
endif()
add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_int8)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_int4)
add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp)
add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_int4)
endif()
add_custom_target(example_batched_gemm_xdl)
add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32)
endif()
add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32)
add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16)
endif()
add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16)
add_example_executable(example_batched_gemm_xdl_bf16 batched_gemm_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16)
endif()
add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16)
add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8)
endif()
add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp)
if(result EQUAL 0)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4)
endif()
add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp)
add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4)
endif()
add_custom_target(example_contraction)
add_custom_target(example_contraction_scale)
add_custom_target(example_contraction_bilinear)
# FP32
add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32)
add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32)
add_example_executable(example_contraction_bilinear_xdl_fp32_compute_bf16 contraction_bilinear_xdl_fp32_compute_bf16.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16)
add_example_executable(example_contraction_scale_xdl_fp32_compute_bf16 contraction_scale_xdl_fp32_compute_bf16.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16)
add_example_executable(example_contraction_bilinear_xdl_fp32_compute_fp16 contraction_bilinear_xdl_fp32_compute_fp16.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16)
add_example_executable(example_contraction_scale_xdl_fp32_compute_fp16 contraction_scale_xdl_fp32_compute_fp16.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16)
# FP64
add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64)
add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64)
add_example_executable(example_contraction_bilinear_xdl_fp64_compute_fp32 contraction_bilinear_xdl_fp64_compute_fp32.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32)
add_example_executable(example_contraction_scale_xdl_fp64_compute_fp32 contraction_scale_xdl_fp64_compute_fp32.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32)
# FP16
add_example_executable(example_contraction_bilinear_xdl_fp16_compute_fp32 contraction_bilinear_xdl_fp16_compute_fp32.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32)
add_example_executable(example_contraction_scale_xdl_fp16_compute_fp32 contraction_scale_xdl_fp16_compute_fp32.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32)
# BF16
add_example_executable(example_contraction_bilinear_xdl_bf16_compute_fp32 contraction_bilinear_xdl_bf16_compute_fp32.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32)
add_example_executable(example_contraction_scale_xdl_bf16_compute_fp32 contraction_scale_xdl_bf16_compute_fp32.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32)
add_dependencies(example_contraction example_contraction_scale)
add_dependencies(example_contraction example_contraction_bilinear)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using F64 = double;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Generic instances for fp32, fp16 and bf16 data types.
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
// clang-format off
using DeviceOpInstanceKK_Generic = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
// clang-format off
using DeviceOpInstanceKN_Generic = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
// clang-format off
using DeviceOpInstanceMK_Generic = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
// clang-format off
using DeviceOpInstanceMN_Generic = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
// clang-format on
// Fp64 instances.
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
// clang-format off
using DeviceOpInstanceKK_FP64 = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
// clang-format off
using DeviceOpInstanceKN_FP64 = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
// clang-format off
using DeviceOpInstanceMK_FP64 = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
// clang-format off
using DeviceOpInstanceMN_FP64 = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "common_instances.hpp"
using ADataType = BF16;
using BDataType = BF16;
using AccDataType = F32;
using CShuffleDataType = BF16;
using DDataType = BF16;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = BF16;
using ComputeDataType = F32;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstance = DeviceOpInstanceKKNN;
#include "run_contraction_bilinear_example.inc"
int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "common_instances.hpp"
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DDataType = F16;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F16;
using ComputeDataType = F32;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstance = DeviceOpInstanceKKNN;
#include "run_contraction_bilinear_example.inc"
int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#include "common_instances.hpp"
using ADataType = F32;
using BDataType = F32;
......@@ -32,6 +13,7 @@ using CShuffleDataType = F32;
using DDataType = F32;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F32;
using ComputeDataType = F32;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
......@@ -41,253 +23,64 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceOpInstanceKKNN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>;
using DeviceOpInstanceKNNN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>;
using DeviceOpInstanceMKNN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>;
using DeviceOpInstanceMNNN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>;
// clang-format on
using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstance = DeviceOpInstanceKKNN;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// A[M0, M1, K0, K1]
std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
// B[N0, N1, K0, K1]
std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
// D[M0, M1, N0, N1]
std::vector<ck::index_t> d_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> d_ms_ns_strides{524288, 4096, 128, 1};
// E[M0, M1, N0, N1]
std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
float alpha = 1.f;
float beta = 1.f;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 28)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
const ck::index_t M0 = std::stoi(argv[4]);
const ck::index_t M1 = std::stoi(argv[5]);
const ck::index_t N0 = std::stoi(argv[6]);
const ck::index_t N1 = std::stoi(argv[7]);
const ck::index_t K0 = std::stoi(argv[8]);
const ck::index_t K1 = std::stoi(argv[9]);
a_ms_ks_lengths = {M0, M1, K0, K1};
a_ms_ks_strides = {
std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])};
b_ns_ks_lengths = {N0, N1, K0, K1};
b_ns_ks_strides = {
std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])};
d_ms_ns_lengths = {M0, M1, N0, N1};
d_ms_ns_strides = {
std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])};
e_ms_ns_lengths = {M0, M1, N0, N1};
e_ms_ns_strides = {
std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])};
alpha = std::stof(argv[26]);
beta = std::stof(argv[27]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n");
printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n");
printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
printf("arg26 to 27: alpha, beta\n");
exit(0);
}
Tensor<ADataType> a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides);
Tensor<BDataType> b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides);
Tensor<EDataType> d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides);
Tensor<EDataType> e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<EDataType> e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides);
std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl;
std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_ms_ks.mData.data());
b_device_buf.ToDevice(b_ns_ks.mData.data());
d_device_buf.ToDevice(d_ms_ns.mData.data());
// set zero
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t M =
ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});
ck::index_t N = ck::accumulate_n<ck::index_t>(
e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{});
ck::index_t K = ck::accumulate_n<ck::index_t>(
a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_ms_ns_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp>;
auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument =
ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
c_ms_ns_host_result(m0, m1, n0, n1),
d_ms_ns(m0, m1, n0, n1));
}
}
}
}
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
}
#include "run_contraction_bilinear_example.inc"
return 0;
}
int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment