"vscode:/vscode.git/clone" did not exist on "692b7a907d64f9ca375eb09cc211e632b7767693"
Commit a26c802d authored by wangshaojie6's avatar wangshaojie6
Browse files

add some code for bfp16

parent 2927524e
...@@ -18,9 +18,9 @@ ...@@ -18,9 +18,9 @@
#include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" #include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "reference_conv_backward_weight.hpp" #include "reference_conv_backward_weight.hpp"
using InDataType = ck::half_t; using InDataType = ck::bhalf_t;
using WeiDataType = ck::half_t; using WeiDataType = ck::bhalf_t;
using OutDataType = ck::half_t; using OutDataType = ck::bhalf_t;
using AccDataType = float; using AccDataType = float;
template <ck::index_t... Is> template <ck::index_t... Is>
......
...@@ -963,37 +963,81 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -963,37 +963,81 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
{ {
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
{ {
const auto kernel = kernel_gemm_xdlops_bwd_weight< if(kbatch == 1)
GridwiseGemm, {
ADataType, // TODO: distiguish A/B datatype const auto kernel = kernel_gemm_xdlops_bwd_weight<
CDataType, GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, ADataType, // TODO: distiguish A/B datatype
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, CDataType,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
OutElementwiseOperation, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
InElementwiseOperation, remove_reference_t<
WeiElementwiseOperation, DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<DeviceOp::Block2CTileMap>, OutElementwiseOperation,
true>; InElementwiseOperation,
WeiElementwiseOperation,
Run(kernel); remove_reference_t<DeviceOp::Block2CTileMap>,
true>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemmAtomicAddFloatForBf16,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
true>;
Run(kernel);
}
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_bwd_weight< if(kbatch == 1)
GridwiseGemm, {
ADataType, // TODO: distiguish A/B datatype const auto kernel = kernel_gemm_xdlops_bwd_weight<
CDataType, GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, ADataType, // TODO: distiguish A/B datatype
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, CDataType,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
OutElementwiseOperation, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
InElementwiseOperation, remove_reference_t<
WeiElementwiseOperation, DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<DeviceOp::Block2CTileMap>, OutElementwiseOperation,
false>; InElementwiseOperation,
WeiElementwiseOperation,
Run(kernel); remove_reference_t<DeviceOp::Block2CTileMap>,
false>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemmAtomicAddFloatForBf16,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
false>;
Run(kernel);
}
} }
} }
else else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment