"vscode:/vscode.git/clone" did not exist on "c41b30a8e6d46d4774806e493d7aa35fd2b8f366"
Unverified Commit 42facfc6 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Add conv bwd weight fp16 comp bf8 fp8 op, instances and example (#945)



* Add f8 bf8 gemm example

* Add element-wise ops

* Add intrinsics

* Update reference calculation

* Add an additional type option for xdlops gemm

* Fix build process

* Add bf8 to buffer addressing

* Update blockwise op, split typeA and typeB

* Update for compatibility

* Uppdate naming to f8->fp8

* Update naming

* Format

* Update naming (#937)

* Add a client example

* Add computetypes to device and gridwise ops

* Add instances, update instance factory

* Format

* Fix a flag

* Add ckProfiler mode

* Fix typos

* Add an example

* Add bf8 generator

* add bf8 mfma; fixed type_convert for bf8

* move verfication ahead of timing

* Update reference calculation

* Fix reference

* Narrow down float init range

* Fix bf8 bf8 mfma

* Add bf8 @ fp8 mfma

* Update example

* Update instances

* Update profiler api

* Update for compatibility

* Format

* Remove extra example

* Clean up

* workaround convert

---------
Co-authored-by: default avatarJing Zhang <jizha@amd.com>
parent e921e1f0
...@@ -11,6 +11,12 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -11,6 +11,12 @@ foreach(gpu IN LISTS GPU_TARGETS)
if(result EQUAL 0) if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
endif() endif()
if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
endif()
endif()
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
......
...@@ -23,6 +23,12 @@ ...@@ -23,6 +23,12 @@
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
#ifdef CK_ENABLE_FP8
using F8 = ck::f8_t;
#endif
#ifdef CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
......
...@@ -65,6 +65,15 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::DeviceGroupedC ...@@ -65,6 +65,15 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::DeviceGroupedC
5, // CThreadTransferSrcDstVectorDim 5, // CThreadTransferSrcDstVectorDim
4>; // CThreadTransferDstScalarPerVector 4>; // CThreadTransferDstScalarPerVector
template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
#include "run_grouped_conv_bwd_weight_example.inc" #include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
...@@ -67,6 +67,15 @@ using DeviceConvBwdWeightInstance = ...@@ -67,6 +67,15 @@ using DeviceConvBwdWeightInstance =
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl
template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
#include "run_grouped_conv_bwd_weight_example.inc" #include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
...@@ -66,6 +66,15 @@ using DeviceConvBwdWeightInstance = ...@@ -66,6 +66,15 @@ using DeviceConvBwdWeightInstance =
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl
template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
#include "run_grouped_conv_bwd_weight_example.inc" #include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
using InDataType = F16;
using WeiDataType = F16;
using OutDataType = F16;
using AccDataType = F32;
using ComputeTypeA = BF8;
using ComputeTypeB = F8;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWC,
ck::tensor_layout::convolution::GNHWC,
ck::tensor_layout::convolution::GNDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GKXC,
ck::tensor_layout::convolution::GKYXC,
ck::tensor_layout::convolution::GKZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWK,
ck::tensor_layout::convolution::GNHWK,
ck::tensor_layout::convolution::GNDHWK>>,
InDataType, // InDataType
WeiDataType, // WeiDataType
OutDataType, // OutDataType
AccDataType, // AccDataType
InElementOp, // InElementwiseOperation
WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
1, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
1, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
2, // CBlockTransferScalarPerVector_NWaveNPerXdl
ComputeTypeA, // ComputeTypeA
ComputeTypeB>; // ComputeTypeB
template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ComputeTypeA,
ComputeTypeB>;
#include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
const ck::utils::conv::ConvParam& conv_param) const ck::utils::conv::ConvParam& conv_param)
...@@ -46,8 +37,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -46,8 +37,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}); out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
break; break;
default: default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0}); in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 0.2});
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5}); out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.1, 0.1});
} }
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
...@@ -113,18 +104,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -113,18 +104,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
return true; return true;
} }
float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); invoker.Run(argument, StreamConfig{nullptr, false});
std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl
<< "DeviceOp: " << conv.GetTypeString() << std::endl;
if(config.do_verification) if(config.do_verification)
{ {
...@@ -148,6 +128,19 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -148,6 +128,19 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData); return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData);
} }
float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl
<< "DeviceOp: " << conv.GetTypeString() << std::endl;
return true; return true;
} }
......
...@@ -20,7 +20,9 @@ template <ck::index_t NDimSpatial, ...@@ -20,7 +20,9 @@ template <ck::index_t NDimSpatial,
typename OutDataType, typename OutDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation> typename OutElementwiseOperation,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGroupedConvBwdWeight : public BaseOperator struct DeviceGroupedConvBwdWeight : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
......
...@@ -48,7 +48,8 @@ struct ComputePtrOffsetOfStridedBatch ...@@ -48,7 +48,8 @@ struct ComputePtrOffsetOfStridedBatch
} // namespace } // namespace
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatA,
typename FloatB,
typename FloatC, typename FloatC,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
...@@ -64,8 +65,8 @@ __global__ void ...@@ -64,8 +65,8 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_batched_gemm_xdlops_bwd_weight( kernel_batched_gemm_xdlops_bwd_weight(
const FloatAB* __restrict__ p_a_grid, const FloatA* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
...@@ -91,7 +92,7 @@ __global__ void ...@@ -91,7 +92,7 @@ __global__ void
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)]; __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
...@@ -163,7 +164,9 @@ template <ck::index_t NDimSpatial, ...@@ -163,7 +164,9 @@ template <ck::index_t NDimSpatial,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl> index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
: public DeviceGroupedConvBwdWeight<NDimSpatial, : public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout, InLayout,
...@@ -174,7 +177,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -174,7 +177,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
OutDataType, OutDataType,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation> OutElementwiseOperation,
ComputeTypeA,
ComputeTypeB>
{ {
using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle; using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle;
...@@ -1045,7 +1050,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1045,7 +1050,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize, BlockSize,
ADataType, // TODO: distinguish A/B datatype ADataType,
BDataType,
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
...@@ -1090,7 +1096,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1090,7 +1096,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
CBlockTransferScalarPerVector_NWaveNPerXdl, CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true, true,
true>; true,
1,
PipelineVersion::v1,
ComputeTypeA,
ComputeTypeB>;
// Argument // Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
...@@ -1217,8 +1227,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1217,8 +1227,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
InElementwiseOperation a_element_op_; OutElementwiseOperation a_element_op_;
OutElementwiseOperation b_element_op_; InElementwiseOperation b_element_op_;
WeiElementwiseOperation c_element_op_; WeiElementwiseOperation c_element_op_;
// for checking IsSupportedArgument() // for checking IsSupportedArgument()
...@@ -1281,7 +1291,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1281,7 +1291,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType,
BDataType,
CDataType, CDataType,
OutElementwiseOperation, OutElementwiseOperation,
InElementwiseOperation, InElementwiseOperation,
......
...@@ -185,7 +185,8 @@ struct PassThrough ...@@ -185,7 +185,8 @@ struct PassThrough
template <> template <>
__host__ __device__ void operator()<bf8_t, half_t>(bf8_t& y, const half_t& x) const __host__ __device__ void operator()<bf8_t, half_t>(bf8_t& y, const half_t& x) const
{ {
y = type_convert<bf8_t>(x); // to-do: fix half_t to bf8_t convert
y = ck::type_convert<bf8_t>(ck::type_convert<float>(x));
} }
#endif #endif
}; };
......
...@@ -139,7 +139,8 @@ __host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLen ...@@ -139,7 +139,8 @@ __host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLen
} }
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatA,
typename FloatB,
typename FloatC, typename FloatC,
typename AGridDesc_B_K0_M_K1, typename AGridDesc_B_K0_M_K1,
typename BGridDesc_B_K0_N_K1, typename BGridDesc_B_K0_N_K1,
...@@ -153,8 +154,8 @@ __global__ void ...@@ -153,8 +154,8 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdlops_bwd_weight(const FloatAB* __restrict__ p_a_grid, kernel_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
...@@ -181,21 +182,22 @@ __global__ void ...@@ -181,21 +182,22 @@ __global__ void
c_element_op, c_element_op,
c_block_cluster_adaptor); c_block_cluster_adaptor);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_c_grid; ignore = p_c_grid;
ignore = a_b_k0_m_k1_grid_desc; ignore = a_b_k0_m_k1_grid_desc;
ignore = b_b_k0_n_k1_grid_desc; ignore = b_b_k0_n_k1_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = c_element_op; ignore = c_element_op;
ignore = c_block_cluster_adaptor; ignore = c_block_cluster_adaptor;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB, typename FloatA,
typename FloatB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
...@@ -242,7 +244,9 @@ template <index_t BlockSize, ...@@ -242,7 +244,9 @@ template <index_t BlockSize,
bool ABlockLdsExtraM1Wrw = false, bool ABlockLdsExtraM1Wrw = false,
bool BBlockLdsExtraN1Wrw = false, bool BBlockLdsExtraN1Wrw = false,
index_t NumGemmKPrefetchStage = 1, index_t NumGemmKPrefetchStage = 1,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeTypeA = FloatA,
typename ComputeTypeB = ComputeTypeA>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -265,11 +269,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -265,11 +269,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// denorm test fix, required to work around fp16 mfma issue // denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update // when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file // FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB,
// throughout this file
#if CK_WORKAROUND_DENORM_FIX #if CK_WORKAROUND_DENORM_FIX
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>; using FloatAAdjusted =
conditional_t<is_same_v<ComputeTypeA, ck::half_t>, ck::bhalf_t, ComputeTypeA>;
using FloatBAdjusted =
conditional_t<is_same_v<ComputeTypeB, ck::half_t>, ck::bhalf_t, ComputeTypeB>;
#else #else
using FloatABAdjusted = FloatAB; using FloatAAdjusted = ComputeTypeA;
using FloatBAdjusted = ComputeTypeB;
#endif #endif
// M0/M1/M1Padding // M0/M1/M1Padding
...@@ -506,7 +515,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -506,7 +515,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto c_block_size = constexpr auto c_block_size =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB), return math::max((a_block_space_size * sizeof(FloatAAdjusted) +
b_block_space_size * sizeof(FloatBAdjusted)),
c_block_size * sizeof(FloatC)); c_block_size * sizeof(FloatC));
} }
...@@ -610,8 +620,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -610,8 +620,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
template <bool HasMainKBlockLoop> template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
...@@ -673,8 +683,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -673,8 +683,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
Sequence<1, K0PerBlock, MPerBlock, K1>, Sequence<1, K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatA,
FloatABAdjusted, FloatAAdjusted,
decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc), decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -703,8 +713,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -703,8 +713,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
Sequence<1, K0PerBlock, NPerBlock, K1>, Sequence<1, K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatB,
FloatABAdjusted, FloatBAdjusted,
decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc), decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -733,12 +743,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -733,12 +743,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// sanity check // sanity check
constexpr index_t KPack = constexpr index_t KPack =
math::max(K1, MfmaSelector<FloatABAdjusted, MPerXDL, NPerXDL>::selected_mfma.k_per_blk); math::max(K1,
MfmaSelector<FloatAAdjusted, MPerXDL, NPerXDL, FloatBAdjusted>::selected_mfma
.k_per_blk);
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatABAdjusted, FloatAAdjusted,
FloatABAdjusted, FloatBAdjusted,
FloatAcc, FloatAcc,
decltype(a_k0_m_k1_block_desc), decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc), decltype(b_k0_n_k1_block_desc),
...@@ -758,10 +770,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -758,10 +770,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatABAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); static_cast<FloatAAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size, static_cast<FloatBAdjusted*>(p_shared) + a_block_space_size,
b_k0_n_k1_block_desc.GetElementSpaceSize()); b_k0_n_k1_block_desc.GetElementSpaceSize());
// gridwise GEMM pipeline // gridwise GEMM pipeline
......
...@@ -32,8 +32,12 @@ enum struct MfmaInstr ...@@ -32,8 +32,12 @@ enum struct MfmaInstr
mfma_f64_16x16x4f64, mfma_f64_16x16x4f64,
mfma_f32_32x32x16f8f8, mfma_f32_32x32x16f8f8,
mfma_f32_16x16x32f8f8, mfma_f32_16x16x32f8f8,
mfma_f32_32x32x16bf8bf8,
mfma_f32_16x16x32bf8bf8,
mfma_f32_32x32x16f8bf8, mfma_f32_32x32x16f8bf8,
mfma_f32_16x16x32f8bf8 mfma_f32_16x16x32f8bf8,
mfma_f32_32x32x16bf8f8,
mfma_f32_16x16x32bf8f8
}; };
template <MfmaInstr instr> template <MfmaInstr instr>
...@@ -504,6 +508,52 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8> ...@@ -504,6 +508,52 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
}; };
#endif #endif
#if defined CK_ENABLE_BF8
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8bf8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16bf8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32bf8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8bf8> struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8bf8>
...@@ -550,6 +600,52 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8> ...@@ -550,6 +600,52 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
}; };
#endif #endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8f8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16bf8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32bf8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
#endif
template <typename base_type, template <typename base_type,
index_t MPerXdlops, index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -710,6 +806,20 @@ struct MfmaSelector ...@@ -710,6 +806,20 @@ struct MfmaSelector
} }
#endif #endif
#if defined CK_ENABLE_BF8
template <>
static constexpr auto GetMfma<bf8_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x16bf8bf8;
}
template <>
static constexpr auto GetMfma<bf8_t, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
}
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>() static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
...@@ -724,6 +834,20 @@ struct MfmaSelector ...@@ -724,6 +834,20 @@ struct MfmaSelector
} }
#endif #endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <>
static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
{
return MfmaInstr::mfma_f32_32x32x16bf8f8;
}
template <>
static constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
{
return MfmaInstr::mfma_f32_16x16x32bf8f8;
}
#endif
static constexpr auto selected_mfma = static constexpr auto selected_mfma =
mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{}; mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
...@@ -931,8 +1055,12 @@ struct XdlopsGemm ...@@ -931,8 +1055,12 @@ struct XdlopsGemm
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8
|| is_same<base_type, f8_t>::value || is_same<base_type, f8_t>::value
#endif #endif
#if defined CK_ENABLE_BF8
|| is_same<base_type, bf8_t>::value
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|| (is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value) || (is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value) ||
(is_same<base_type, bf8_t>::value && is_same<additional_type, f8_t>::value)
#endif #endif
, ,
"base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"); "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
......
...@@ -420,6 +420,71 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> ...@@ -420,6 +420,71 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
}; };
#endif #endif
#if defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8bf8;
template <>
struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
{
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf8bf8;
template <>
struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
{
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8bf8; struct intrin_mfma_f32_32x32x16f8bf8;
...@@ -484,5 +549,70 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16> ...@@ -484,5 +549,70 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
} }
}; };
#endif #endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8f8;
template <>
struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
{
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf8f8;
template <>
struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
{
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
#endif
} // namespace ck } // namespace ck
#endif #endif
...@@ -221,7 +221,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x) ...@@ -221,7 +221,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion // convert to float and use native converion
return type_convert<f8_t>(type_convert<float>(x)); return type_convert<bf8_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
......
...@@ -25,6 +25,8 @@ template <ck::index_t NDimSpatial, ...@@ -25,6 +25,8 @@ template <ck::index_t NDimSpatial,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
typename ComputeTypeA = OutDataType,
typename ComputeTypeB = InDataType,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false> typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdWeight : public device::BaseOperator struct ReferenceConvBwdWeight : public device::BaseOperator
{ {
...@@ -98,8 +100,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -98,8 +100,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{ {
float v_out; ComputeTypeA v_out;
float v_in; ComputeTypeB v_in;
arg.out_element_op_( arg.out_element_op_(
v_out, ck::type_convert<float>(arg.output_(g, n, k, wo))); v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
...@@ -107,7 +109,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -107,7 +109,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
arg.in_element_op_( arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi))); v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
v_acc += v_out * v_in; v_acc += type_convert<float>(v_out) * type_convert<float>(v_in);
} }
} }
} }
...@@ -158,8 +160,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -158,8 +160,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{ {
float v_out; ComputeTypeA v_out;
float v_in; ComputeTypeB v_in;
arg.out_element_op_( arg.out_element_op_(
v_out, v_out,
...@@ -168,7 +170,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -168,7 +170,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
arg.in_element_op_( arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi))); v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
v_acc += v_out * v_in; v_acc += type_convert<float>(v_out) * type_convert<float>(v_in);
} }
} }
} }
...@@ -226,8 +228,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -226,8 +228,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) <
arg.input_.GetLengths()[5]) arg.input_.GetLengths()[5])
{ {
float v_out; ComputeTypeA v_out;
float v_in; ComputeTypeB v_in;
arg.out_element_op_(v_out, arg.out_element_op_(v_out,
ck::type_convert<float>( ck::type_convert<float>(
...@@ -237,7 +239,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -237,7 +239,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
ck::type_convert<float>( ck::type_convert<float>(
arg.input_(g, n, c, di, hi, wi))); arg.input_(g, n, c, di, hi, wi)));
v_acc += v_out * v_in; v_acc +=
type_convert<float>(v_out) * type_convert<float>(v_in);
} }
} }
} }
......
...@@ -19,6 +19,14 @@ using BF16 = ck::bhalf_t; ...@@ -19,6 +19,14 @@ using BF16 = ck::bhalf_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
#ifdef CK_ENABLE_FP8
using F8 = ck::f8_t;
#endif
#ifdef CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
using Empty_Tuple = ck::Tuple<>; using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is> template <ck::index_t... Is>
...@@ -133,6 +141,43 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances = std::tuple< ...@@ -133,6 +141,43 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances = std::tuple<
// clang-format on // clang-format on
>; >;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec>
using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_instances = std::tuple<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| Compute| Compute|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| TypeA| TypeB|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
// generic instance
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8>,
// instance for small conv.K
// for fp16 conv.K and conv.C must be divisible by 2
// since half_t atomic_add require scalar_per_x_vector % 2 == 0
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8>,
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>
#endif
// clang-format on
>;
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -216,6 +216,21 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances ...@@ -216,6 +216,21 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
BF8,
F8>>>& instances);
#endif
#ifdef DL_KERNELS #ifdef DL_KERNELS
// dl // dl
...@@ -464,7 +479,9 @@ template <ck::index_t NumDimSpatial, ...@@ -464,7 +479,9 @@ template <ck::index_t NumDimSpatial,
typename OutLayout, typename OutLayout,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType> typename OutDataType,
typename ComputeTypeA,
typename ComputeTypeB>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvBwdWeight< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
...@@ -475,7 +492,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -475,7 +492,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
OutDataType, OutDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>> ck::tensor_operation::element_wise::PassThrough,
ComputeTypeA,
ComputeTypeB>>
{ {
using DeviceOp = DeviceGroupedConvBwdWeight<NumDimSpatial, using DeviceOp = DeviceGroupedConvBwdWeight<NumDimSpatial,
InLayout, InLayout,
...@@ -486,7 +505,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -486,7 +505,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
OutDataType, OutDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>; ck::tensor_operation::element_wise::PassThrough,
ComputeTypeA,
ComputeTypeB>;
static auto GetInstances() static auto GetInstances()
{ {
...@@ -706,7 +727,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -706,7 +727,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif #endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t> &&
is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{ {
#ifdef DL_KERNELS #ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
...@@ -728,6 +751,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -728,6 +751,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
op_ptrs); op_ptrs);
} }
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> &&
is_same_v<ComputeTypeA, bf8_t> && is_same_v<ComputeTypeB, f8_t>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
op_ptrs);
}
#endif #endif
} }
} }
......
...@@ -111,6 +111,22 @@ struct GeneratorTensor_2<ck::f8_t> ...@@ -111,6 +111,22 @@ struct GeneratorTensor_2<ck::f8_t>
}; };
#endif #endif
#if defined CK_ENABLE_BF8
template <>
struct GeneratorTensor_2<ck::bf8_t>
{
int min_value = 0;
int max_value = 1;
template <typename... Is>
ck::bf8_t operator()(Is...)
{
float tmp = (std::rand() % (max_value - min_value)) + min_value;
return ck::type_convert<ck::bf8_t>(tmp);
}
};
#endif
template <typename T> template <typename T>
struct GeneratorTensor_3 struct GeneratorTensor_3
{ {
...@@ -162,6 +178,25 @@ struct GeneratorTensor_3<ck::f8_t> ...@@ -162,6 +178,25 @@ struct GeneratorTensor_3<ck::f8_t>
}; };
#endif #endif
#if defined CK_ENABLE_BF8
template <>
struct GeneratorTensor_3<ck::bf8_t>
{
float min_value = 0;
float max_value = 1;
template <typename... Is>
ck::bf8_t operator()(Is...)
{
float tmp = float(std::rand()) / float(RAND_MAX);
float fp32_tmp = min_value + tmp * (max_value - min_value);
return ck::type_convert<ck::bf8_t>(fp32_tmp);
}
};
#endif
template <typename T> template <typename T>
struct GeneratorTensor_4 struct GeneratorTensor_4
{ {
......
...@@ -4,7 +4,8 @@ set(GROUPED_CONV3D_BWD_WEIGHT ...@@ -4,7 +4,8 @@ set(GROUPED_CONV3D_BWD_WEIGHT
device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
if(DL_KERNELS) if(DL_KERNELS)
list(APPEND GROUPED_CONV3D_BWD_WEIGHT list(APPEND GROUPED_CONV3D_BWD_WEIGHT
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
BF8,
F8>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment