"docs/vscode:/vscode.git/clone" did not exist on "86d5ec3d0f1b5c62b6018fdf1588d9895e47e48a"
Commit 8f41bd8e authored by Jun Liu

Merge branch 'develop' into amd-develop

parents 7f65ac05 d7f05fb9
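
In summary, the merge pulls the develop branch's 8-bit float work into amd-develop: a bf8-input/f8-weight 3D forward-convolution client example plus matching device instances and factory wiring, convnd forward examples that either store bf8/f8 or store fp16 while computing in fp8, a grouped backward-data example computing in bf8/f8, new CK_USE_XDL / CK_USE_WMMA configuration guards, rank-6 (up to twelve tensor dimensions) support in the host reference contraction, and CODEOWNERS/doc-tooling updates.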
-* @zjing14 @junliume @illsilin @carlushuang @aosewski
+* @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex
# Documentation files
docs/* @ROCm/rocm-documentation
*.md @ROCm/rocm-documentation
...
@@ -20,6 +20,9 @@ endif()
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
    add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp)
    target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations)
    add_executable(client_conv3d_fwd_bf8_fp8 conv3d_fwd_bf8_fp8.cpp)
    target_link_libraries(client_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations)
endif()
if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES)
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using InDataType   = ck::bf8_t; // activations stored in bf8
using WeiDataType  = ck::f8_t;  // weights stored in f8
using OutDataType  = ck::f8_t;
using InLayout     = ck::tensor_layout::convolution::NDHWGC; // batch, depth, height, width, group, channel
using WeiLayout    = ck::tensor_layout::convolution::GKZYXC; // group, output channel, filter z/y/x, input channel
using OutLayout    = ck::tensor_layout::convolution::NDHWGK;
using AComputeType = ck::bf8_t; // input operand computed in bf8
using BComputeType = ck::f8_t;  // weight operand computed in f8
static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G = 1;
static constexpr ck::index_t N = 64;
static constexpr ck::index_t K = 128;
static constexpr ck::index_t C = 64;
static constexpr ck::index_t Z = 3;
static constexpr ck::index_t Y = 3;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Di = 28;
static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3;
int main()
{
return run_grouped_conv_fwd<NumDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout,
3,
AComputeType,
BComputeType>(
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
? EXIT_SUCCESS
: EXIT_FAILURE;
}
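
A quick arithmetic check of the extents above: assuming the usual pad-1, stride-1, dilation-1 configuration (the convolution strides and pads are defaulted inside run_grouped_conv_fwd and are not visible in this file), a 3x3x3 filter preserves every spatial extent, which is exactly what the Do/Ho/Wo constants encode. A hypothetical standalone check, not part of the client example:

// conv_out_size is an illustrative helper, not a CK function.
constexpr int conv_out_size(int in, int filter, int pad = 1, int stride = 1)
{
    return (in + 2 * pad - filter) / stride + 1;
}
static_assert(conv_out_size(28, 3) == 28, "Do == Di and Ho == Hi under pad 1, stride 1");
static_assert(conv_out_size(3, 3) == 3, "Wo == Wi under pad 1, stride 1");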
-rocm-docs-core==0.38.0
+rocm-docs-core==0.38.1
sphinxcontrib-bibtex==2.6.2
@@ -111,7 +111,7 @@ requests==2.31.0
    # via
    #   pygithub
    #   sphinx
-rocm-docs-core==0.38.0
+rocm-docs-core==0.38.1
    # via -r requirements.in
six==1.16.0
    # via
...
@@ -5,7 +5,9 @@ add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp)
add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp)
add_example_executable(example_convnd_fwd_xdl_fp16_comp_fp8 convnd_fwd_xdl_fp16_comp_fp8.cpp)
add_example_executable(example_convnd_fwd_xdl_fp8_bf8 convnd_fwd_xdl_fp8_bf8.cpp)
add_example_executable(example_convnd_fwd_xdl_bf8_fp8 convnd_fwd_xdl_bf8_fp8.cpp)
add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp)
add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp)
add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = ck::bf8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = ck::f8_t;
using OutDataType = ck::f8_t;
using AComputeType = ck::bf8_t;
using BComputeType = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<>,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, // NumGemmKPrefetchStage
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,              // CShuffleMXdlPerWavePerShuffle
1,              // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8,              // CDEBlockTransferScalarPerVector_NPerBlock
AComputeType,
BComputeType>;
#include "run_convnd_fwd_example.inc"
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
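
For reference on the two 8-bit types paired in this example: assuming CK's ck::f8_t and ck::bf8_t follow the common OCP FP8 convention, f8_t is E4M3 (three mantissa bits, maximum normal value 448) and bf8_t is E5M2 (fp16-like exponent range, maximum normal value 57344). The bf8_fp8 pairing therefore gives the input tensor the wider dynamic range and the weight tensor the extra mantissa bit; this is the mirror image of the fp8_bf8 example registered just above it in the CMake list.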
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using OutDataType = ck::half_t;
using ComputeType = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::UnaryConvert;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<>,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, // NumGemmKPrefetchStage
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,              // CShuffleMXdlPerWavePerShuffle
1,              // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8,              // CDEBlockTransferScalarPerVector_NPerBlock
ComputeType>;
#include "run_convnd_fwd_example.inc"
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
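
The point of this example is that the tensors stay in fp16 while ComputeType = ck::f8_t asks the kernel to perform the inner-product arithmetic at fp8 precision (presumably why OutElementOp is UnaryConvert rather than PassThrough here). A minimal sketch of the numeric contract, reusing names this example already includes; mac_simulated is a hypothetical helper, not CK's internal code path:

// Each fp16 operand is rounded to fp8 before the multiply; accumulation
// stays in float, matching AccDataType above.
inline float mac_simulated(float acc, ck::half_t a, ck::half_t b)
{
    const ck::f8_t a_f8 = ck::type_convert<ck::f8_t>(a);
    const ck::f8_t b_f8 = ck::type_convert<ck::f8_t>(b);
    return acc + ck::type_convert<float>(a_f8) * ck::type_convert<float>(b_f8);
}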
@@ -3,6 +3,9 @@ add_custom_target(example_grouped_conv_bwd_data)
add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16)
add_example_executable(example_grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp)
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8)
add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16)
...
@@ -34,6 +34,8 @@ static constexpr auto ConvBwdDataDefault =
using FP16 = ck::half_t;
using FP32 = float;
using FP8  = ck::f8_t;
using BF8  = ck::bf8_t;
struct ExecutionConfig final
{
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
#include "common.hpp"
using OutDataType = FP16;
using WeiDataType = FP16;
using AccDataType = FP32;
using CShuffleDataType = FP16;
using DsDataType = ck::Tuple<>;
using InDataType = FP16;
using AComputeType = BF8;
using BComputeType = FP8;
using OutLayout = ck::tensor_layout::convolution::GNHWK;
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
using DsLayout = ck::Tuple<>;
using InLayout = ck::tensor_layout::convolution::GNHWC;
using OutElementOp = PassThrough;
using WeiElementOp = PassThrough;
using InElementOp = PassThrough;
static constexpr auto LoopSched = ck::make_default_loop_scheduler();
// clang-format off
using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| Loop| ACompute| BCompute|
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| Scheduler| Type| Type|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, AComputeType, BComputeType>;
// clang-format on
#include "run_grouped_conv_bwd_data_example.inc"
int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); }
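
Note the operand order in the instance above: for backward data, A is the output gradient (OutLayout and OutDataType come first) and B is the weight, so AComputeType = BF8 applies to the gradient tensor and BComputeType = FP8 to the weights, while everything is stored in fp16. This is the same storage-versus-compute split as the forward fp16_comp examples.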
@@ -104,6 +104,20 @@
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
//
// CK kernels which support XDL (MI series)
//
#ifndef CK_USE_XDL
#cmakedefine CK_USE_XDL @CK_USE_XDL@
#endif
//
// CK kernels which support WMMA (recent Navi series)
//
#ifndef CK_USE_WMMA
#cmakedefine CK_USE_WMMA @CK_USE_WMMA@
#endif
// clang-format on
#endif // CK_CONFIG_H_IN
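
These guards let instantiation sites compile out ISA-specific kernels once CMake has decided which targets are enabled. An illustrative consumer, not taken from this diff:

#ifdef CK_USE_XDL
// register XDL (MFMA) kernel instances for MI-series GPUs
#endif
#ifdef CK_USE_WMMA
// register WMMA kernel instances for recent Navi GPUs
#endif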
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -15,7 +15,6 @@ namespace ck {
namespace tensor_operation {
namespace host {
-// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
          ck::index_t NumDimN,
          ck::index_t NumDimK,
@@ -26,7 +25,9 @@ template <ck::index_t NumDimM,
          typename ComputeDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
-          ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
+          ck::enable_if_t<(NumDimM == 2 || NumDimM == 6) && (NumDimN == 2 || NumDimN == 6) &&
+                              (NumDimK == 2 || NumDimK == 6),
+                          bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
    // Argument
@@ -60,9 +61,28 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
    float Run(const Argument& arg)
    {
-        auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
-            const ck::index_t K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
-            const ck::index_t K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
+        auto f_ms_ns = [&](auto m0,
+                           auto m1,
+                           auto m2,
+                           auto m3,
+                           auto m4,
+                           auto m5,
+                           auto n0,
+                           auto n1,
+                           auto n2,
+                           auto n3,
+                           auto n4,
+                           auto n5) {
+            const ck::index_t K0 = arg.a_ms_ks_.mDesc.GetLengths()[NumDimM];
+            const ck::index_t K1 = arg.a_ms_ks_.mDesc.GetLengths()[NumDimM + 1];
+            const ck::index_t K2 =
+                NumDimK >= 3 ? arg.a_ms_ks_.mDesc.GetLengths()[NumDimM + 2] : 1;
+            const ck::index_t K3 =
+                NumDimK >= 4 ? arg.a_ms_ks_.mDesc.GetLengths()[NumDimM + 3] : 1;
+            const ck::index_t K4 =
+                NumDimK >= 5 ? arg.a_ms_ks_.mDesc.GetLengths()[NumDimM + 4] : 1;
+            const ck::index_t K5 =
+                NumDimK >= 6 ? arg.a_ms_ks_.mDesc.GetLengths()[NumDimM + 5] : 1;

            AccDataType v_acc = 0;
@@ -70,32 +90,96 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
            {
                for(ck::index_t k1 = 0; k1 < K1; ++k1)
                {
-                    // Simulate the possible casting when ComputeDataType is different than the
-                    // A/B data types
-                    ComputeDataType v_a_compute_input =
-                        ck::type_convert<ComputeDataType>(arg.a_ms_ks_(m0, m1, k0, k1));
-                    ComputeDataType v_b_compute_input =
-                        ck::type_convert<ComputeDataType>(arg.b_ns_ks_(n0, n1, k0, k1));
-
-                    AccDataType v_a;
-                    AccDataType v_b;
-
-                    arg.a_element_op_(v_a, ck::type_convert<AccDataType>(v_a_compute_input));
-                    arg.b_element_op_(v_b, ck::type_convert<AccDataType>(v_b_compute_input));
-
-                    v_acc += v_a * v_b;
+                    for(ck::index_t k2 = 0; k2 < K2; ++k2)
+                    {
+                        for(ck::index_t k3 = 0; k3 < K3; ++k3)
+                        {
+                            for(ck::index_t k4 = 0; k4 < K4; ++k4)
+                            {
+                                for(ck::index_t k5 = 0; k5 < K5; ++k5)
+                                {
+                                    ComputeDataType v_a_compute_input;
+                                    ComputeDataType v_b_compute_input;
+
+                                    // Simulate the possible casting when ComputeDataType is
+                                    // different than the A/B data types
+                                    if constexpr(NumDimK == 2)
+                                    {
+                                        v_a_compute_input = ck::type_convert<ComputeDataType>(
+                                            arg.a_ms_ks_(m0, m1, k0, k1));
+                                        v_b_compute_input = ck::type_convert<ComputeDataType>(
+                                            arg.b_ns_ks_(n0, n1, k0, k1));
+                                    }
+                                    else if constexpr(NumDimK == 6)
+                                    {
+                                        v_a_compute_input = ck::type_convert<
+                                            ComputeDataType>(arg.a_ms_ks_(
+                                            m0, m1, m2, m3, m4, m5, k0, k1, k2, k3, k4, k5));
+                                        v_b_compute_input = ck::type_convert<
+                                            ComputeDataType>(arg.b_ns_ks_(
+                                            n0, n1, n2, n3, n4, n5, k0, k1, k2, k3, k4, k5));
+                                    }
+
+                                    AccDataType v_a;
+                                    AccDataType v_b;
+
+                                    arg.a_element_op_(
+                                        v_a, ck::type_convert<AccDataType>(v_a_compute_input));
+                                    arg.b_element_op_(
+                                        v_b, ck::type_convert<AccDataType>(v_b_compute_input));
+
+                                    v_acc += v_a * v_b;
+                                }
+                            }
+                        }
+                    }
                }
            }

-            arg.c_ms_ns_(m0, m1, n0, n1) = ck::type_convert<CDataType>(v_acc);
+            if constexpr(NumDimK == 2)
+            {
+                arg.c_ms_ns_(m0, m1, n0, n1) = ck::type_convert<CDataType>(v_acc);
+            }
+            else if constexpr(NumDimK == 6)
+            {
+                arg.c_ms_ns_(m0, m1, m2, m3, m4, m5, n0, n1, n2, n3, n4, n5) =
+                    ck::type_convert<CDataType>(v_acc);
+            }
        };

-        make_ParallelTensorFunctor(f_ms_ns,
-                                   arg.c_ms_ns_.mDesc.GetLengths()[0],
-                                   arg.c_ms_ns_.mDesc.GetLengths()[1],
-                                   arg.c_ms_ns_.mDesc.GetLengths()[2],
-                                   arg.c_ms_ns_.mDesc.GetLengths()[3])(
-            std::thread::hardware_concurrency());
+        if constexpr(NumDimK == 2)
+        {
+            make_ParallelTensorFunctor(f_ms_ns,
+                                       arg.c_ms_ns_.mDesc.GetLengths()[0],
+                                       arg.c_ms_ns_.mDesc.GetLengths()[1],
+                                       1,
+                                       1,
+                                       1,
+                                       1,
+                                       arg.c_ms_ns_.mDesc.GetLengths()[2],
+                                       arg.c_ms_ns_.mDesc.GetLengths()[3],
+                                       1,
+                                       1,
+                                       1,
+                                       1)(std::thread::hardware_concurrency());
+        }
+        else if constexpr(NumDimK == 6)
+        {
+            make_ParallelTensorFunctor(f_ms_ns,
+                                       arg.c_ms_ns_.mDesc.GetLengths()[0],
+                                       arg.c_ms_ns_.mDesc.GetLengths()[1],
+                                       arg.c_ms_ns_.mDesc.GetLengths()[2],
+                                       arg.c_ms_ns_.mDesc.GetLengths()[3],
+                                       arg.c_ms_ns_.mDesc.GetLengths()[4],
+                                       arg.c_ms_ns_.mDesc.GetLengths()[5],
+                                       arg.c_ms_ns_.mDesc.GetLengths()[6],
+                                       arg.c_ms_ns_.mDesc.GetLengths()[7],
+                                       arg.c_ms_ns_.mDesc.GetLengths()[8],
+                                       arg.c_ms_ns_.mDesc.GetLengths()[9],
+                                       arg.c_ms_ns_.mDesc.GetLengths()[10],
+                                       arg.c_ms_ns_.mDesc.GetLengths()[11])(
+                std::thread::hardware_concurrency());
+        }

        return 0;
    }
...
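The rank dispatch above leans on a degenerate-extent trick: when NumDimK == 2, the extents K2..K5 evaluate to 1, the four added loops each execute exactly once, and the unused m2..m5/n2..n5 lambda arguments are never read, so both supported ranks share one 6-deep loop nest. The NumDimK == 2 branch of the make_ParallelTensorFunctor call applies the same idea by padding the missing output extents with 1. In isolation (a hypothetical helper, not CK code):

// A fixed-depth loop nest covers lower ranks when absent extents are padded with 1.
template <int Rank>
long total_iterations(const long (&ext)[6])
{
    long n = 1;
    for(int d = 0; d < 6; ++d)
        n *= (d < Rank) ? ext[d] : 1; // absent dimensions contribute a factor of 1
    return n;
}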
@@ -326,6 +326,42 @@ using device_grouped_conv_fwd_xdl_f8_bf8_instances = std::tuple<
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_bf8_f8_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|AComputeType|BComputeType|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8))
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>
#endif
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
...
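Within the new tuple, the first entry is the deliberately conservative one marked as the generic instance: its A/B SrcScalarPerVector and the CDE ScalarPerVector are all 1, so it imposes no vectorized-access constraints on conv.K and conv.C and remains applicable when the faster, tuned instances below it are filtered out for a given problem size.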
@@ -301,6 +301,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
        add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(op_ptrs);
    }
#endif
#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8))
if constexpr(is_same_v<InDataType, ck::bf8_t> && is_same_v<WeiDataType, ck::f8_t> &&
is_same_v<OutDataType, ck::f8_t> && is_same_v<AComputeType, ck::bf8_t> &&
is_same_v<BComputeType, ck::f8_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
        if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                     is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
...
@@ -369,6 +369,24 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(
                                                                BF8>>>& instances);
#endif
#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8))
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF8,
F8,
Empty_Tuple,
F8,
PassThrough,
PassThrough,
PassThrough,
BF8,
F8>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
...
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -408,6 +408,37 @@ struct Tensor
            mDesc.GetLengths()[5])(num_thread);
        break;
    }
case 12: {
auto f = [&](auto i0,
auto i1,
auto i2,
auto i3,
auto i4,
auto i5,
auto i6,
auto i7,
auto i8,
auto i9,
auto i10,
auto i11) {
(*this)(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11) =
g(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11);
};
make_ParallelTensorFunctor(f,
mDesc.GetLengths()[0],
mDesc.GetLengths()[1],
mDesc.GetLengths()[2],
mDesc.GetLengths()[3],
mDesc.GetLengths()[4],
mDesc.GetLengths()[5],
mDesc.GetLengths()[6],
mDesc.GetLengths()[7],
mDesc.GetLengths()[8],
mDesc.GetLengths()[9],
mDesc.GetLengths()[10],
mDesc.GetLengths()[11])(num_thread);
break;
}
    default: throw std::runtime_error("unspported dimension");
    }
}
...
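The new case 12 is what the rank-6 contraction path above requires: with NumDimM == NumDimN == NumDimK == 6, the host tensors a_ms_ks, b_ns_ks, and c_ms_ns all carry twelve indices, and this extension lets them be initialized and filled through the same ParallelTensorFunctor machinery as lower-rank tensors.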