Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
cd36ffea
Unverified
Commit
cd36ffea
authored
Nov 25, 2023
by
david-cortes
Committed by
GitHub
Nov 24, 2023
Browse files
[R-package] Fix inefficiency in retrieving pointers (#6208)
parent
516bde95
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
8 deletions
+16
-8
R-package/src/lightgbm_R.cpp
R-package/src/lightgbm_R.cpp
+16
-8
No files found.
R-package/src/lightgbm_R.cpp
View file @
cd36ffea
...
...
@@ -226,9 +226,10 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
int32_t
len
=
static_cast
<
int32_t
>
(
Rf_asInteger
(
len_used_row_indices
));
std
::
vector
<
int32_t
>
idxvec
(
len
);
// convert from one-based to zero-based index
const
int
*
used_row_indices_
=
INTEGER
(
used_row_indices
);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (len >= 1024)
for
(
int32_t
i
=
0
;
i
<
len
;
++
i
)
{
idxvec
[
i
]
=
static_cast
<
int32_t
>
(
INTEGER
(
used_row_indices
)
[
i
]
-
1
);
idxvec
[
i
]
=
static_cast
<
int32_t
>
(
used_row_indices
_
[
i
]
-
1
);
}
const
char
*
parameters_ptr
=
CHAR
(
PROTECT
(
Rf_asChar
(
parameters
)));
DatasetHandle
res
=
nullptr
;
...
...
@@ -339,18 +340,20 @@ SEXP LGBM_DatasetSetField_R(SEXP handle,
const
char
*
name
=
CHAR
(
PROTECT
(
Rf_asChar
(
field_name
)));
if
(
!
strcmp
(
"group"
,
name
)
||
!
strcmp
(
"query"
,
name
))
{
std
::
vector
<
int32_t
>
vec
(
len
);
const
int
*
field_data_
=
INTEGER
(
field_data
);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (len >= 1024)
for
(
int
i
=
0
;
i
<
len
;
++
i
)
{
vec
[
i
]
=
static_cast
<
int32_t
>
(
INTEGER
(
field_data
)
[
i
]);
vec
[
i
]
=
static_cast
<
int32_t
>
(
field_data
_
[
i
]);
}
CHECK_CALL
(
LGBM_DatasetSetField
(
R_ExternalPtrAddr
(
handle
),
name
,
vec
.
data
(),
len
,
C_API_DTYPE_INT32
));
}
else
if
(
!
strcmp
(
"init_score"
,
name
))
{
CHECK_CALL
(
LGBM_DatasetSetField
(
R_ExternalPtrAddr
(
handle
),
name
,
REAL
(
field_data
),
len
,
C_API_DTYPE_FLOAT64
));
}
else
{
std
::
vector
<
float
>
vec
(
len
);
const
double
*
field_data_
=
REAL
(
field_data
);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (len >= 1024)
for
(
int
i
=
0
;
i
<
len
;
++
i
)
{
vec
[
i
]
=
static_cast
<
float
>
(
REAL
(
field_data
)
[
i
]);
vec
[
i
]
=
static_cast
<
float
>
(
field_data
_
[
i
]);
}
CHECK_CALL
(
LGBM_DatasetSetField
(
R_ExternalPtrAddr
(
handle
),
name
,
vec
.
data
(),
len
,
C_API_DTYPE_FLOAT32
));
}
...
...
@@ -372,21 +375,24 @@ SEXP LGBM_DatasetGetField_R(SEXP handle,
if
(
!
strcmp
(
"group"
,
name
)
||
!
strcmp
(
"query"
,
name
))
{
auto
p_data
=
reinterpret_cast
<
const
int32_t
*>
(
res
);
// convert from boundaries to size
int
*
field_data_
=
INTEGER
(
field_data
);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (out_len >= 1024)
for
(
int
i
=
0
;
i
<
out_len
-
1
;
++
i
)
{
INTEGER
(
field_data
)
[
i
]
=
p_data
[
i
+
1
]
-
p_data
[
i
];
field_data
_
[
i
]
=
p_data
[
i
+
1
]
-
p_data
[
i
];
}
}
else
if
(
!
strcmp
(
"init_score"
,
name
))
{
auto
p_data
=
reinterpret_cast
<
const
double
*>
(
res
);
double
*
field_data_
=
REAL
(
field_data
);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (out_len >= 1024)
for
(
int
i
=
0
;
i
<
out_len
;
++
i
)
{
REAL
(
field_data
)
[
i
]
=
p_data
[
i
];
field_data
_
[
i
]
=
p_data
[
i
];
}
}
else
{
auto
p_data
=
reinterpret_cast
<
const
float
*>
(
res
);
double
*
field_data_
=
REAL
(
field_data
);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (out_len >= 1024)
for
(
int
i
=
0
;
i
<
out_len
;
++
i
)
{
REAL
(
field_data
)
[
i
]
=
p_data
[
i
];
field_data
_
[
i
]
=
p_data
[
i
];
}
}
UNPROTECT
(
1
);
...
...
@@ -611,10 +617,12 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
int
is_finished
=
0
;
int
int_len
=
Rf_asInteger
(
len
);
std
::
vector
<
float
>
tgrad
(
int_len
),
thess
(
int_len
);
const
double
*
grad_
=
REAL
(
grad
);
const
double
*
hess_
=
REAL
(
hess
);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (int_len >= 1024)
for
(
int
j
=
0
;
j
<
int_len
;
++
j
)
{
tgrad
[
j
]
=
static_cast
<
float
>
(
REAL
(
grad
)
[
j
]);
thess
[
j
]
=
static_cast
<
float
>
(
REAL
(
hess
)
[
j
]);
tgrad
[
j
]
=
static_cast
<
float
>
(
grad
_
[
j
]);
thess
[
j
]
=
static_cast
<
float
>
(
hess
_
[
j
]);
}
CHECK_CALL
(
LGBM_BoosterUpdateOneIterCustom
(
R_ExternalPtrAddr
(
handle
),
tgrad
.
data
(),
thess
.
data
(),
&
is_finished
));
return
R_NilValue
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment