Commit a26c802d authored by wangshaojie6's avatar wangshaojie6
Browse files

add some code for bfp16

parent 2927524e
...@@ -18,9 +18,9 @@ ...@@ -18,9 +18,9 @@
#include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" #include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "reference_conv_backward_weight.hpp" #include "reference_conv_backward_weight.hpp"
using InDataType = ck::half_t; using InDataType = ck::bhalf_t;
using WeiDataType = ck::half_t; using WeiDataType = ck::bhalf_t;
using OutDataType = ck::half_t; using OutDataType = ck::bhalf_t;
using AccDataType = float; using AccDataType = float;
template <ck::index_t... Is> template <ck::index_t... Is>
......
...@@ -962,6 +962,8 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -962,6 +962,8 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value) if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
{ {
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
{
if(kbatch == 1)
{ {
const auto kernel = kernel_gemm_xdlops_bwd_weight< const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm, GridwiseGemm,
...@@ -969,7 +971,26 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -969,7 +971,26 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
true>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemmAtomicAddFloatForBf16,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation, OutElementwiseOperation,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
...@@ -978,7 +999,10 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -978,7 +999,10 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
Run(kernel); Run(kernel);
} }
}
else else
{
if(kbatch == 1)
{ {
const auto kernel = kernel_gemm_xdlops_bwd_weight< const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm, GridwiseGemm,
...@@ -986,7 +1010,26 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -986,7 +1010,26 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
false>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemmAtomicAddFloatForBf16,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation, OutElementwiseOperation,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
...@@ -996,6 +1039,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -996,6 +1039,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
Run(kernel); Run(kernel);
} }
} }
}
else else
{ {
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment