Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
77d9529d
"src/vscode:/vscode.git/clone" did not exist on "dc6995742a5284a1e942978e2542fc49adda9ea1"
Unverified
Commit
77d9529d
authored
Jun 28, 2021
by
Robin Dong
Committed by
GitHub
Jun 28, 2021
Browse files
[CUDA] fix CUDA memory error by reducing block number (fixed #4315) (#4327)
parent
b5502d19
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
16 deletions
+16
-16
src/treelearner/cuda_kernel_launcher.cu
src/treelearner/cuda_kernel_launcher.cu
+16
-16
No files found.
src/treelearner/cuda_kernel_launcher.cu
View file @
77d9529d
...
@@ -38,20 +38,20 @@ void cuda_histogram(
...
@@ -38,20 +38,20 @@ void cuda_histogram(
if
(
leaf_num_data
==
num_data
)
{
if
(
leaf_num_data
==
num_data
)
{
if
(
use_all_features
)
{
if
(
use_all_features
)
{
if
(
!
is_constant_hessian
)
if
(
!
is_constant_hessian
)
histogram16
<<<
16
*
num_workgroups
,
16
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
histogram16
<<<
num_workgroups
,
16
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
arg3
,
arg4
,
arg5
,
arg6
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
arg6
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
else
else
histogram16
<<<
16
*
num_workgroups
,
16
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
histogram16
<<<
num_workgroups
,
16
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
arg3
,
arg4
,
arg5
,
arg6_const
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
arg6_const
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
}
else
{
}
else
{
if
(
!
is_constant_hessian
)
if
(
!
is_constant_hessian
)
histogram16_fulldata
<<<
16
*
num_workgroups
,
16
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
histogram16_fulldata
<<<
num_workgroups
,
16
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
arg3
,
arg4
,
arg5
,
arg6
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
arg6
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
else
else
histogram16_fulldata
<<<
16
*
num_workgroups
,
16
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
histogram16_fulldata
<<<
num_workgroups
,
16
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
arg3
,
arg4
,
arg5
,
arg6_const
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
arg6_const
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
}
}
...
@@ -59,20 +59,20 @@ void cuda_histogram(
...
@@ -59,20 +59,20 @@ void cuda_histogram(
if
(
use_all_features
)
{
if
(
use_all_features
)
{
// seems all features is always enabled, so this should be the same as fulldata
// seems all features is always enabled, so this should be the same as fulldata
if
(
!
is_constant_hessian
)
if
(
!
is_constant_hessian
)
histogram16
<<<
16
*
num_workgroups
,
16
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
histogram16
<<<
num_workgroups
,
16
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
arg3
,
arg4
,
arg5
,
arg6
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
arg6
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
else
else
histogram16
<<<
16
*
num_workgroups
,
16
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
histogram16
<<<
num_workgroups
,
16
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
arg3
,
arg4
,
arg5
,
arg6_const
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
arg6_const
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
}
else
{
}
else
{
if
(
!
is_constant_hessian
)
if
(
!
is_constant_hessian
)
histogram16
<<<
16
*
num_workgroups
,
16
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
histogram16
<<<
num_workgroups
,
16
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
arg3
,
arg4
,
arg5
,
arg6
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
arg6
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
else
else
histogram16
<<<
16
*
num_workgroups
,
16
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
histogram16
<<<
num_workgroups
,
16
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
arg3
,
arg4
,
arg5
,
arg6_const
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
arg6_const
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
}
}
...
@@ -81,20 +81,20 @@ void cuda_histogram(
...
@@ -81,20 +81,20 @@ void cuda_histogram(
if
(
leaf_num_data
==
num_data
)
{
if
(
leaf_num_data
==
num_data
)
{
if
(
use_all_features
)
{
if
(
use_all_features
)
{
if
(
!
is_constant_hessian
)
if
(
!
is_constant_hessian
)
histogram64
<<<
4
*
num_workgroups
,
64
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
histogram64
<<<
num_workgroups
,
64
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
arg3
,
arg4
,
arg5
,
arg6
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
arg6
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
else
else
histogram64
<<<
4
*
num_workgroups
,
64
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
histogram64
<<<
num_workgroups
,
64
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
arg3
,
arg4
,
arg5
,
arg6_const
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
arg6_const
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
}
else
{
}
else
{
if
(
!
is_constant_hessian
)
if
(
!
is_constant_hessian
)
histogram64_fulldata
<<<
4
*
num_workgroups
,
64
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
histogram64_fulldata
<<<
num_workgroups
,
64
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
arg3
,
arg4
,
arg5
,
arg6
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
arg6
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
else
else
histogram64_fulldata
<<<
4
*
num_workgroups
,
64
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
histogram64_fulldata
<<<
num_workgroups
,
64
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
arg3
,
arg4
,
arg5
,
arg6_const
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
arg6_const
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
}
}
...
@@ -102,20 +102,20 @@ void cuda_histogram(
...
@@ -102,20 +102,20 @@ void cuda_histogram(
if
(
use_all_features
)
{
if
(
use_all_features
)
{
// seems all features is always enabled, so this should be the same as fulldata
// seems all features is always enabled, so this should be the same as fulldata
if
(
!
is_constant_hessian
)
if
(
!
is_constant_hessian
)
histogram64
<<<
4
*
num_workgroups
,
64
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
histogram64
<<<
num_workgroups
,
64
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
arg3
,
arg4
,
arg5
,
arg6
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
arg6
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
else
else
histogram64
<<<
4
*
num_workgroups
,
64
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
histogram64
<<<
num_workgroups
,
64
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
arg3
,
arg4
,
arg5
,
arg6_const
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
arg6_const
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
}
else
{
}
else
{
if
(
!
is_constant_hessian
)
if
(
!
is_constant_hessian
)
histogram64
<<<
4
*
num_workgroups
,
64
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
histogram64
<<<
num_workgroups
,
64
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
arg3
,
arg4
,
arg5
,
arg6
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
arg6
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
else
else
histogram64
<<<
4
*
num_workgroups
,
64
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
histogram64
<<<
num_workgroups
,
64
,
0
,
stream
>>>
(
arg0
,
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
arg3
,
arg4
,
arg5
,
arg6_const
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
arg6_const
,
arg7
,
arg8
,
static_cast
<
acc_type
*>
(
arg9
),
exp_workgroups_per_feature
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment