"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "84c657e85df9adf7d69c68ff4b1697470f39cee3"
Unverified commit a46c68fe authored by shiyu1994, committed by GitHub

[CUDA] Add rank_xendcg objective for cuda_exp (#5472)
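For reference, the kernels added in this diff compute the following per query, with scores $s_i$, labels $y_i$, and per-item uniform draws $g_i \sim U(0,1)$; these formulas are read directly off the code below and match the three accumulation passes of the CPU RankXENDCG objective:

$$\rho_i = \frac{e^{s_i - m}}{\sum_k e^{s_k - m}}, \qquad m = \max_k s_k, \qquad \phi_i = 2^{y_i} - g_i$$

$$\lambda_i^{(1)} = \rho_i - \frac{\phi_i}{\max\!\big(\epsilon, \sum_k \phi_k\big)}, \qquad p_i^{(1)} = \frac{\lambda_i^{(1)}}{1 - \rho_i}$$

$$\lambda_i^{(2)} = \lambda_i^{(1)} + \rho_i \Big(\textstyle\sum_k p_k^{(1)} - p_i^{(1)}\Big), \qquad p_i^{(2)} = \frac{\rho_i \big(\sum_k p_k^{(1)} - p_i^{(1)}\big)}{1 - \rho_i}$$

$$\lambda_i = \lambda_i^{(2)} + \rho_i \Big(\textstyle\sum_k p_k^{(2)} - p_i^{(2)}\Big), \qquad h_i = \rho_i (1 - \rho_i)$$

$\lambda_i$ is written out as the gradient and $h_i$ as the hessian.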

parent 7c503ba3
@@ -60,6 +60,47 @@ void CUDALambdarankNDCG::GetGradients(const double* score, score_t* gradients, s
}
CUDARankXENDCG::CUDARankXENDCG(const Config& config): CUDALambdarankNDCG(config) {}
CUDARankXENDCG::CUDARankXENDCG(const std::vector<std::string>& strs): CUDALambdarankNDCG(strs) {}
CUDARankXENDCG::~CUDARankXENDCG() {}
void CUDARankXENDCG::Init(const Metadata& metadata, data_size_t num_data) {
CUDALambdarankNDCG::Init(metadata, num_data);
  // one random stream per query, seeded deterministically from the objective seed
  for (data_size_t i = 0; i < num_queries_; ++i) {
    rands_.emplace_back(seed_ + i);
  }
item_rands_.resize(num_data, 0.0f);
AllocateCUDAMemory<double>(&cuda_item_rands_, static_cast<size_t>(num_data), __FILE__, __LINE__);
  if (max_items_in_query_aligned_ >= 2048) {
    // queries too long for the shared-memory kernels stage per-item params in global memory
    AllocateCUDAMemory<double>(&cuda_params_buffer_, static_cast<size_t>(num_data_), __FILE__, __LINE__);
  }
}
void CUDARankXENDCG::GenerateItemRands() const {
const int num_threads = OMP_NUM_THREADS();
OMP_INIT_EX();
#pragma omp parallel for schedule(static) num_threads(num_threads)
for (data_size_t i = 0; i < num_queries_; ++i) {
OMP_LOOP_EX_BEGIN();
const data_size_t start = query_boundaries_[i];
const data_size_t end = query_boundaries_[i + 1];
for (data_size_t j = start; j < end; ++j) {
item_rands_[j] = rands_[i].NextFloat();
}
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
}
void CUDARankXENDCG::GetGradients(const double* score, score_t* gradients, score_t* hessians) const {
GenerateItemRands();
CopyFromHostToCUDADevice<double>(cuda_item_rands_, item_rands_.data(), item_rands_.size(), __FILE__, __LINE__);
LaunchGetGradientsKernel(score, gradients, hessians);
}
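Host-side flow for each gradient call: fresh uniforms are drawn into `item_rands_` on the CPU (one deterministic `Random` stream per query, so the draws are reproducible for a fixed seed and independent of the OpenMP schedule), copied to `cuda_item_rands_`, and only then is the device kernel launched.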
}  // namespace LightGBM
#endif  // USE_CUDA_EXP
@@ -377,6 +377,282 @@ void CUDALambdarankNDCG::LaunchGetGradientsKernel(const double* score, score_t*
}
// randomized gain phi(l, g) = 2^l - g, with g drawn uniformly from [0, 1)
__device__ __forceinline__ double CUDAPhi(const label_t l, double g) {
  return pow(2.0, static_cast<double>(l)) - g;
}
template <size_t SHARED_MEMORY_SIZE>
__global__ void GetGradientsKernel_RankXENDCG_SharedMemory(
const double* cuda_scores,
const label_t* cuda_labels,
const double* cuda_item_rands,
const data_size_t num_data,
const data_size_t num_queries,
const data_size_t* cuda_query_boundaries,
score_t* cuda_out_gradients,
score_t* cuda_out_hessians) {
const data_size_t query_index_start = static_cast<data_size_t>(blockIdx.x) * NUM_QUERY_PER_BLOCK;
const data_size_t query_index_end = min(query_index_start + NUM_QUERY_PER_BLOCK, num_queries);
for (data_size_t query_index = query_index_start; query_index < query_index_end; ++query_index) {
const data_size_t item_index_start = cuda_query_boundaries[query_index];
const data_size_t item_index_end = cuda_query_boundaries[query_index + 1];
const data_size_t query_item_count = item_index_end - item_index_start;
score_t* cuda_out_gradients_pointer = cuda_out_gradients + item_index_start;
score_t* cuda_out_hessians_pointer = cuda_out_hessians + item_index_start;
const label_t* cuda_labels_pointer = cuda_labels + item_index_start;
const double* cuda_scores_pointer = cuda_scores + item_index_start;
const double* cuda_item_rands_pointer = cuda_item_rands + item_index_start;
const data_size_t block_reduce_size = query_item_count >= 1024 ? 1024 : query_item_count;
__shared__ double shared_rho[SHARED_MEMORY_SIZE];
// assert that warpSize == 32
__shared__ double shared_buffer[32];
__shared__ double shared_params[SHARED_MEMORY_SIZE];
__shared__ score_t shared_lambdas[SHARED_MEMORY_SIZE];
__shared__ double reduce_result;
if (query_item_count <= 1) {
      // degenerate query: zero all its gradients and hessians (note: i < query_item_count, not <=, to stay in bounds)
      for (data_size_t i = 0; i < query_item_count; ++i) {
cuda_out_gradients_pointer[i] = 0.0f;
cuda_out_hessians_pointer[i] = 0.0f;
}
__syncthreads();
} else {
// compute softmax
double thread_reduce_result = kMinScore;
for (data_size_t i = static_cast<data_size_t>(threadIdx.x); i < query_item_count; i += static_cast<data_size_t>(blockDim.x)) {
const double rho = cuda_scores_pointer[i];
shared_rho[i] = rho;
if (rho > thread_reduce_result) {
thread_reduce_result = rho;
}
}
__syncthreads();
thread_reduce_result = ShuffleReduceMax<double>(thread_reduce_result, shared_buffer, block_reduce_size);
if (threadIdx.x == 0) {
reduce_result = thread_reduce_result;
}
__syncthreads();
thread_reduce_result = 0.0f;
for (data_size_t i = static_cast<data_size_t>(threadIdx.x); i < query_item_count; i += static_cast<data_size_t>(blockDim.x)) {
const double exp_value = exp(shared_rho[i] - reduce_result);
shared_rho[i] = exp_value;
thread_reduce_result += exp_value;
}
thread_reduce_result = ShuffleReduceSum<double>(thread_reduce_result, shared_buffer, block_reduce_size);
if (threadIdx.x == 0) {
reduce_result = thread_reduce_result;
}
__syncthreads();
for (data_size_t i = static_cast<data_size_t>(threadIdx.x); i < query_item_count; i += static_cast<data_size_t>(blockDim.x)) {
shared_rho[i] /= reduce_result;
}
__syncthreads();
// compute params
thread_reduce_result = 0.0f;
for (data_size_t i = static_cast<data_size_t>(threadIdx.x); i < query_item_count; i += static_cast<data_size_t>(blockDim.x)) {
const double param_value = CUDAPhi(cuda_labels_pointer[i], cuda_item_rands_pointer[i]);
shared_params[i] = param_value;
thread_reduce_result += param_value;
}
thread_reduce_result = ShuffleReduceSum<double>(thread_reduce_result, shared_buffer, block_reduce_size);
if (threadIdx.x == 0) {
reduce_result = thread_reduce_result;
reduce_result = 1.0f / max(kEpsilon, reduce_result);
}
__syncthreads();
const double inv_denominator = reduce_result;
      // first-order terms: lambda_i = rho_i - phi_i / sum(phi)
      thread_reduce_result = 0.0f;
for (data_size_t i = static_cast<data_size_t>(threadIdx.x); i < query_item_count; i += static_cast<data_size_t>(blockDim.x)) {
const double term = -shared_params[i] * inv_denominator + shared_rho[i];
shared_lambdas[i] = static_cast<score_t>(term);
shared_params[i] = term / (1.0f - shared_rho[i]);
thread_reduce_result += shared_params[i];
}
thread_reduce_result = ShuffleReduceSum<double>(thread_reduce_result, shared_buffer, block_reduce_size);
if (threadIdx.x == 0) {
reduce_result = thread_reduce_result;
}
__syncthreads();
const double sum_l1 = reduce_result;
      // second-order correction: lambda_i += rho_i * (sum_l1 - p1_i)
      thread_reduce_result = 0.0f;
for (data_size_t i = static_cast<data_size_t>(threadIdx.x); i < query_item_count; i += static_cast<data_size_t>(blockDim.x)) {
const double term = shared_rho[i] * (sum_l1 - shared_params[i]);
shared_lambdas[i] += static_cast<score_t>(term);
shared_params[i] = term / (1.0f - shared_rho[i]);
thread_reduce_result += shared_params[i];
}
thread_reduce_result = ShuffleReduceSum<double>(thread_reduce_result, shared_buffer, block_reduce_size);
if (threadIdx.x == 0) {
reduce_result = thread_reduce_result;
}
__syncthreads();
const double sum_l2 = reduce_result;
      // third-order correction, plus hessian rho_i * (1 - rho_i)
      for (data_size_t i = static_cast<data_size_t>(threadIdx.x); i < query_item_count; i += static_cast<data_size_t>(blockDim.x)) {
shared_lambdas[i] += static_cast<score_t>(shared_rho[i] * (sum_l2 - shared_params[i]));
cuda_out_hessians_pointer[i] = static_cast<score_t>(shared_rho[i] * (1.0f - shared_rho[i]));
}
for (data_size_t i = static_cast<data_size_t>(threadIdx.x); i < query_item_count; i += static_cast<data_size_t>(blockDim.x)) {
cuda_out_gradients_pointer[i] = shared_lambdas[i];
}
__syncthreads();
}
}
}
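The block-wide reductions above go through LightGBM's `ShuffleReduceMax`/`ShuffleReduceSum` helpers, whose implementation is outside this diff. A minimal sketch of the usual pattern they follow (warp shuffles plus the 32-slot shared scratch buffer, assuming `warpSize == 32` and fully active warps; the name `BlockShuffleSum` is hypothetical, not LightGBM's actual helper):

```cpp
template <typename T>
__device__ __forceinline__ T BlockShuffleSum(T value, T* shared_buffer, const int reduce_size) {
  const unsigned int full_mask = 0xffffffffu;  // assumes all 32 lanes of each warp are active
  const int warp_id = static_cast<int>(threadIdx.x) / 32;
  const int lane_id = static_cast<int>(threadIdx.x) % 32;
  // 1) warp-local reduction in registers
  for (int offset = 16; offset > 0; offset >>= 1) {
    value += __shfl_down_sync(full_mask, value, offset);
  }
  // 2) lane 0 of each warp publishes its partial sum to the 32-slot buffer
  if (lane_id == 0) {
    shared_buffer[warp_id] = value;
  }
  __syncthreads();
  // 3) the first warp reduces the per-warp partials; the total lands on thread 0,
  //    which is the thread the kernels above use to publish reduce_result
  if (warp_id == 0) {
    const int num_warps = (reduce_size + 31) / 32;
    value = (lane_id < num_warps) ? shared_buffer[lane_id] : static_cast<T>(0);
    for (int offset = 16; offset > 0; offset >>= 1) {
      value += __shfl_down_sync(full_mask, value, offset);
    }
  }
  return value;
}
```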
__global__ void GetGradientsKernel_RankXENDCG_GlobalMemory(
const double* cuda_scores,
const label_t* cuda_labels,
const double* cuda_item_rands,
const data_size_t num_data,
const data_size_t num_queries,
const data_size_t* cuda_query_boundaries,
double* cuda_params_buffer,
score_t* cuda_out_gradients,
score_t* cuda_out_hessians) {
const data_size_t query_index_start = static_cast<data_size_t>(blockIdx.x) * NUM_QUERY_PER_BLOCK;
const data_size_t query_index_end = min(query_index_start + NUM_QUERY_PER_BLOCK, num_queries);
for (data_size_t query_index = query_index_start; query_index < query_index_end; ++query_index) {
const data_size_t item_index_start = cuda_query_boundaries[query_index];
const data_size_t item_index_end = cuda_query_boundaries[query_index + 1];
const data_size_t query_item_count = item_index_end - item_index_start;
score_t* cuda_out_gradients_pointer = cuda_out_gradients + item_index_start;
score_t* cuda_out_hessians_pointer = cuda_out_hessians + item_index_start;
const label_t* cuda_labels_pointer = cuda_labels + item_index_start;
const double* cuda_scores_pointer = cuda_scores + item_index_start;
const double* cuda_item_rands_pointer = cuda_item_rands + item_index_start;
double* cuda_params_buffer_pointer = cuda_params_buffer + item_index_start;
const data_size_t block_reduce_size = query_item_count > 1024 ? 1024 : query_item_count;
// assert that warpSize == 32, so we use buffer size 1024 / 32 = 32
__shared__ double shared_buffer[32];
__shared__ double reduce_result;
if (query_item_count <= 1) {
      // degenerate query: zero all its gradients and hessians (note: i < query_item_count, not <=, to stay in bounds)
      for (data_size_t i = 0; i < query_item_count; ++i) {
cuda_out_gradients_pointer[i] = 0.0f;
cuda_out_hessians_pointer[i] = 0.0f;
}
__syncthreads();
} else {
// compute softmax
double thread_reduce_result = kMinScore;
for (data_size_t i = static_cast<data_size_t>(threadIdx.x); i < query_item_count; i += static_cast<data_size_t>(blockDim.x)) {
const double rho = cuda_scores_pointer[i];
if (rho > thread_reduce_result) {
thread_reduce_result = rho;
}
}
__syncthreads();
thread_reduce_result = ShuffleReduceMax<double>(thread_reduce_result, shared_buffer, block_reduce_size);
if (threadIdx.x == 0) {
reduce_result = thread_reduce_result;
}
__syncthreads();
thread_reduce_result = 0.0f;
for (data_size_t i = static_cast<data_size_t>(threadIdx.x); i < query_item_count; i += static_cast<data_size_t>(blockDim.x)) {
const double exp_value = exp(cuda_scores_pointer[i] - reduce_result);
cuda_out_hessians_pointer[i] = exp_value;
thread_reduce_result += exp_value;
}
thread_reduce_result = ShuffleReduceSum<double>(thread_reduce_result, shared_buffer, block_reduce_size);
if (threadIdx.x == 0) {
reduce_result = thread_reduce_result;
}
__syncthreads();
// store probability into hessians
for (data_size_t i = static_cast<data_size_t>(threadIdx.x); i < query_item_count; i += static_cast<data_size_t>(blockDim.x)) {
cuda_out_hessians_pointer[i] /= reduce_result;
}
__syncthreads();
// compute params
thread_reduce_result = 0.0f;
for (data_size_t i = static_cast<data_size_t>(threadIdx.x); i < query_item_count; i += static_cast<data_size_t>(blockDim.x)) {
const double param_value = CUDAPhi(cuda_labels_pointer[i], cuda_item_rands_pointer[i]);
cuda_params_buffer_pointer[i] = param_value;
thread_reduce_result += param_value;
}
thread_reduce_result = ShuffleReduceSum<double>(thread_reduce_result, shared_buffer, block_reduce_size);
if (threadIdx.x == 0) {
reduce_result = thread_reduce_result;
reduce_result = 1.0f / max(kEpsilon, reduce_result);
}
__syncthreads();
const double inv_denominator = reduce_result;
thread_reduce_result = 0.0f;
for (data_size_t i = static_cast<data_size_t>(threadIdx.x); i < query_item_count; i += static_cast<data_size_t>(blockDim.x)) {
const double term = -cuda_params_buffer_pointer[i] * inv_denominator + cuda_out_hessians_pointer[i];
cuda_out_gradients_pointer[i] = static_cast<score_t>(term);
const double param = term / (1.0f - cuda_out_hessians_pointer[i]);
cuda_params_buffer_pointer[i] = param;
thread_reduce_result += param;
}
thread_reduce_result = ShuffleReduceSum<double>(thread_reduce_result, shared_buffer, block_reduce_size);
if (threadIdx.x == 0) {
reduce_result = thread_reduce_result;
}
__syncthreads();
const double sum_l1 = reduce_result;
thread_reduce_result = 0.0f;
for (data_size_t i = static_cast<data_size_t>(threadIdx.x); i < query_item_count; i += static_cast<data_size_t>(blockDim.x)) {
const double term = cuda_out_hessians_pointer[i] * (sum_l1 - cuda_params_buffer_pointer[i]);
cuda_out_gradients_pointer[i] += static_cast<score_t>(term);
const double param = term / (1.0f - cuda_out_hessians_pointer[i]);
cuda_params_buffer_pointer[i] = param;
thread_reduce_result += param;
}
thread_reduce_result = ShuffleReduceSum<double>(thread_reduce_result, shared_buffer, block_reduce_size);
if (threadIdx.x == 0) {
reduce_result = thread_reduce_result;
}
__syncthreads();
const double sum_l2 = reduce_result;
for (data_size_t i = static_cast<data_size_t>(threadIdx.x); i < query_item_count; i += static_cast<data_size_t>(blockDim.x)) {
const double prob = cuda_out_hessians_pointer[i];
cuda_out_gradients_pointer[i] += static_cast<score_t>(prob * (sum_l2 - cuda_params_buffer_pointer[i]));
cuda_out_hessians_pointer[i] = static_cast<score_t>(prob * (1.0f - prob));
}
__syncthreads();
}
}
}
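The global-memory variant mirrors the shared-memory kernel pass for pass, but avoids the per-item shared arrays: the softmax probabilities $\rho_i$ are staged in the hessian output buffer (and only overwritten with $\rho_i(1-\rho_i)$ in the final pass), while the per-pass params live in `cuda_params_buffer_`, so only the 32-slot reduction scratch remains in shared memory. This is why `Init` allocates the extra buffer only for long queries.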
void CUDARankXENDCG::LaunchGetGradientsKernel(const double* score, score_t* gradients, score_t* hessians) const {
const int num_blocks = (num_queries_ + NUM_QUERY_PER_BLOCK - 1) / NUM_QUERY_PER_BLOCK;
if (max_items_in_query_aligned_ <= 1024) {
GetGradientsKernel_RankXENDCG_SharedMemory<1024><<<num_blocks, max_items_in_query_aligned_>>>(
score,
cuda_labels_,
cuda_item_rands_,
num_data_,
num_queries_,
cuda_query_boundaries_,
gradients,
hessians);
} else if (max_items_in_query_aligned_ <= 2 * 1024) {
GetGradientsKernel_RankXENDCG_SharedMemory<2 * 1024><<<num_blocks, 1024>>>(
score,
cuda_labels_,
cuda_item_rands_,
num_data_,
num_queries_,
cuda_query_boundaries_,
gradients,
hessians);
} else {
GetGradientsKernel_RankXENDCG_GlobalMemory<<<num_blocks, 1024>>>(
score,
cuda_labels_,
cuda_item_rands_,
num_data_,
num_queries_,
cuda_query_boundaries_,
cuda_params_buffer_,
gradients,
hessians);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
}
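Dispatch summary: each block processes `NUM_QUERY_PER_BLOCK` consecutive queries. Queries of up to 1024 aligned items take the shared-memory kernel with one thread per item; up to 2048 items, the 2048-slot shared-memory instantiation with 1024 threads striding over items; anything longer falls through to the global-memory kernel. The trailing `SynchronizeCUDADevice(__FILE__, __LINE__)` presumably makes launch or kernel errors surface here with a useful file/line rather than at some later CUDA call.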
}  // namespace LightGBM
#endif  // USE_CUDA_EXP
@@ -50,6 +50,33 @@ class CUDALambdarankNDCG : public CUDAObjectiveInterface, public LambdarankNDCG
  int max_items_in_query_aligned_;
};
class CUDARankXENDCG : public CUDALambdarankNDCG {
public:
explicit CUDARankXENDCG(const Config& config);
explicit CUDARankXENDCG(const std::vector<std::string>& strs);
~CUDARankXENDCG();
void Init(const Metadata& metadata, data_size_t num_data) override;
void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override;
bool IsCUDAObjective() const override { return true; }
private:
void LaunchGetGradientsKernel(const double* score, score_t* gradients, score_t* hessians) const;
void GenerateItemRands() const;
  // mutable because fresh random gains are drawn inside the const GetGradients path
  mutable std::vector<double> item_rands_;
  mutable std::vector<Random> rands_;
  mutable double* cuda_item_rands_;
  mutable double* cuda_params_buffer_;
};
}  // namespace LightGBM
#endif  // USE_CUDA_EXP
@@ -38,8 +38,7 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
} else if (type == std::string("lambdarank")) { } else if (type == std::string("lambdarank")) {
return new CUDALambdarankNDCG(config); return new CUDALambdarankNDCG(config);
} else if (type == std::string("rank_xendcg")) { } else if (type == std::string("rank_xendcg")) {
Log::Warning("Objective rank_xendcg is not implemented in cuda_exp version. Fall back to boosting on CPU."); return new CUDARankXENDCG(config);
return new RankXENDCG(config);
} else if (type == std::string("multiclass")) { } else if (type == std::string("multiclass")) {
Log::Warning("Objective multiclass is not implemented in cuda_exp version. Fall back to boosting on CPU."); Log::Warning("Objective multiclass is not implemented in cuda_exp version. Fall back to boosting on CPU.");
return new MulticlassSoftmax(config); return new MulticlassSoftmax(config);