Commit 2927524e authored by wangshaojie6
Browse files

Merge branch 'add_get_work_space' into bwd_weight_bf16_splitk

parents 35c977cd 2bfc08f1
@@ -1229,14 +1229,13 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
     static size_t GetWorkSpaceSize(const Argument& arg)
     {
         size_t WorkSpaceSize = 0;
-        if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
-        {
-            WorkSpaceSize =
-                arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * sizeof(float);
-        }
-        else
+        if(arg.k_batch_ > 1)
         {
-            WorkSpaceSize = arg.Conv_K_ * 0;
+            if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
+            {
+                WorkSpaceSize =
+                    arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * sizeof(float);
+            }
         }
         return WorkSpaceSize;
     }
@@ -1245,14 +1244,13 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
     static size_t GetWorkSpaceSize(const Argument& arg)
     {
         size_t WorkSpaceSize = 0;
-        if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
-        {
-            WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] *
-                            arg.filter_spatial_lengths_[1] * sizeof(float);
-        }
-        else
+        if(arg.k_batch_ > 1)
         {
-            WorkSpaceSize = arg.Conv_K_ * 0;
+            if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
+            {
+                WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] *
+                                arg.filter_spatial_lengths_[1] * sizeof(float);
+            }
         }
         return WorkSpaceSize;
     }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment