Unverified Commit a15a3704 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

Update CUDA treelearner according to changes introduced for linear trees (#3750)

* Update cuda_tree_learner.cpp

* Update cuda_tree_learner.h

* Update cuda.yml
parent f997a069
...@@ -47,6 +47,7 @@ jobs: ...@@ -47,6 +47,7 @@ jobs:
export ROOT_DOCKER_FOLDER=/LightGBM export ROOT_DOCKER_FOLDER=/LightGBM
cat > docker.env <<EOF cat > docker.env <<EOF
TASK=cuda TASK=cuda
METHOD=source
COMPILER=gcc COMPILER=gcc
GITHUB_ACTIONS=true GITHUB_ACTIONS=true
OS_NAME=linux OS_NAME=linux
......
...@@ -534,8 +534,8 @@ void CUDATreeLearner::InitGPU(int num_gpu) { ...@@ -534,8 +534,8 @@ void CUDATreeLearner::InitGPU(int num_gpu) {
copyDenseFeature(); copyDenseFeature();
} }
Tree* CUDATreeLearner::Train(const score_t* gradients, const score_t *hessians) { Tree* CUDATreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) {
Tree *ret = SerialTreeLearner::Train(gradients, hessians); Tree *ret = SerialTreeLearner::Train(gradients, hessians, is_first_tree);
return ret; return ret;
} }
......
...@@ -45,7 +45,7 @@ class CUDATreeLearner: public SerialTreeLearner { ...@@ -45,7 +45,7 @@ class CUDATreeLearner: public SerialTreeLearner {
~CUDATreeLearner(); ~CUDATreeLearner();
void Init(const Dataset* train_data, bool is_constant_hessian) override; void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) override; void ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) override;
Tree* Train(const score_t* gradients, const score_t *hessians); Tree* Train(const score_t* gradients, const score_t *hessians, bool is_first_tree);
void SetBaggingData(const Dataset* subset, const data_size_t* used_indices, data_size_t num_data) override { void SetBaggingData(const Dataset* subset, const data_size_t* used_indices, data_size_t num_data) override {
SerialTreeLearner::SetBaggingData(subset, used_indices, num_data); SerialTreeLearner::SetBaggingData(subset, used_indices, num_data);
if (subset == nullptr && used_indices != nullptr) { if (subset == nullptr && used_indices != nullptr) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment