Commit 77d9529d (unverified), authored by Robin Dong, committed by GitHub
Browse files

[CUDA] fix CUDA memory error by reducing block number (fixed #4315) (#4327)

parent b5502d19
...@@ -38,20 +38,20 @@ void cuda_histogram( ...@@ -38,20 +38,20 @@ void cuda_histogram(
if (leaf_num_data == num_data) { if (leaf_num_data == num_data) {
if (use_all_features) { if (use_all_features) {
if (!is_constant_hessian) if (!is_constant_hessian)
histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2, histogram16<<<num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2,
arg3, arg4, arg5, arg3, arg4, arg5,
arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature); arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
else else
histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2, histogram16<<<num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2,
arg3, arg4, arg5, arg3, arg4, arg5,
arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature); arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
} else { } else {
if (!is_constant_hessian) if (!is_constant_hessian)
histogram16_fulldata<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2, histogram16_fulldata<<<num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2,
arg3, arg4, arg5, arg3, arg4, arg5,
arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature); arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
else else
histogram16_fulldata<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2, histogram16_fulldata<<<num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2,
arg3, arg4, arg5, arg3, arg4, arg5,
arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature); arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
} }
...@@ -59,20 +59,20 @@ void cuda_histogram( ...@@ -59,20 +59,20 @@ void cuda_histogram(
if (use_all_features) { if (use_all_features) {
// seems all features is always enabled, so this should be the same as fulldata // seems all features is always enabled, so this should be the same as fulldata
if (!is_constant_hessian) if (!is_constant_hessian)
histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2, histogram16<<<num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2,
arg3, arg4, arg5, arg3, arg4, arg5,
arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature); arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
else else
histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2, histogram16<<<num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2,
arg3, arg4, arg5, arg3, arg4, arg5,
arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature); arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
} else { } else {
if (!is_constant_hessian) if (!is_constant_hessian)
histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2, histogram16<<<num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2,
arg3, arg4, arg5, arg3, arg4, arg5,
arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature); arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
else else
histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2, histogram16<<<num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2,
arg3, arg4, arg5, arg3, arg4, arg5,
arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature); arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
} }
...@@ -81,20 +81,20 @@ void cuda_histogram( ...@@ -81,20 +81,20 @@ void cuda_histogram(
if (leaf_num_data == num_data) { if (leaf_num_data == num_data) {
if (use_all_features) { if (use_all_features) {
if (!is_constant_hessian) if (!is_constant_hessian)
histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2, histogram64<<<num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2,
arg3, arg4, arg5, arg3, arg4, arg5,
arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature); arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
else else
histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2, histogram64<<<num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2,
arg3, arg4, arg5, arg3, arg4, arg5,
arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature); arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
} else { } else {
if (!is_constant_hessian) if (!is_constant_hessian)
histogram64_fulldata<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2, histogram64_fulldata<<<num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2,
arg3, arg4, arg5, arg3, arg4, arg5,
arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature); arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
else else
histogram64_fulldata<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2, histogram64_fulldata<<<num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2,
arg3, arg4, arg5, arg3, arg4, arg5,
arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature); arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
} }
...@@ -102,20 +102,20 @@ void cuda_histogram( ...@@ -102,20 +102,20 @@ void cuda_histogram(
if (use_all_features) { if (use_all_features) {
// seems all features is always enabled, so this should be the same as fulldata // seems all features is always enabled, so this should be the same as fulldata
if (!is_constant_hessian) if (!is_constant_hessian)
histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2, histogram64<<<num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2,
arg3, arg4, arg5, arg3, arg4, arg5,
arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature); arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
else else
histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2, histogram64<<<num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2,
arg3, arg4, arg5, arg3, arg4, arg5,
arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature); arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
} else { } else {
if (!is_constant_hessian) if (!is_constant_hessian)
histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2, histogram64<<<num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2,
arg3, arg4, arg5, arg3, arg4, arg5,
arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature); arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
else else
histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2, histogram64<<<num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2,
arg3, arg4, arg5, arg3, arg4, arg5,
arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature); arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment