"...en/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "c425a69d52c714423bbc5a55f6f3c609723993d9"
Commit c50c08dc authored by Kai Wang (Victor Kai)'s avatar Kai Wang (Victor Kai) Committed by binmakeswell
Browse files

[NFC] polish colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu code style (#979)

parent f28c0213
#include <cooperative_groups.h>
#include <chrono> #include <chrono>
#include <ctime> #include <ctime>
#include "kernels.h" #include "kernels.h"
#include <cooperative_groups.h>
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
curandStatePhilox4_32_10_t *curandstate; curandStatePhilox4_32_10_t *curandstate;
...@@ -165,8 +165,7 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio, ...@@ -165,8 +165,7 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
const float scale = 1.f / (1.f - ratio); const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count) if (i * 4 >= total_count) return;
return;
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state); curand_init(seed, i, 0, &state);
...@@ -202,8 +201,7 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio, ...@@ -202,8 +201,7 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count) if (i * 8 >= total_count) return;
return;
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state); curand_init(seed, i, 0, &state);
...@@ -261,8 +259,7 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, ...@@ -261,8 +259,7 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
const float scale = 1.f / (1.f - ratio); const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count) if (i * 4 >= total_count) return;
return;
uint8_t m[4]; uint8_t m[4];
...@@ -289,8 +286,7 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, ...@@ -289,8 +286,7 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count) if (i * 8 >= total_count) return;
return;
float4 *out4 = reinterpret_cast<float4 *>(out); float4 *out4 = reinterpret_cast<float4 *>(out);
const float4 *vals_float4 = reinterpret_cast<const float4 *>(in); const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
...@@ -380,8 +376,7 @@ __global__ void ls_dropout_res_bias_kernel( ...@@ -380,8 +376,7 @@ __global__ void ls_dropout_res_bias_kernel(
const float scale = 1.f / (1.f - ratio); const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count) if (i * 4 >= total_count) return;
return;
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state); curand_init(seed, i, 0, &state);
...@@ -424,8 +419,7 @@ __global__ void ls_dropout_res_bias_kernel( ...@@ -424,8 +419,7 @@ __global__ void ls_dropout_res_bias_kernel(
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count) if (i * 8 >= total_count) return;
return;
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state); curand_init(seed, i, 0, &state);
...@@ -565,11 +559,9 @@ __global__ void ls_dropout_bias_bwd_kernel( ...@@ -565,11 +559,9 @@ __global__ void ls_dropout_bias_bwd_kernel(
} }
__syncthreads(); __syncthreads();
for (int i = 1; i < 32; i <<= 1) for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);
sum += g.shfl_down(sum, i);
if (y == 0) if (y == 0) tile[0][x] = sum;
tile[0][x] = sum;
__syncthreads(); __syncthreads();
if (threadIdx.x < 8) { if (threadIdx.x < 8) {
...@@ -621,11 +613,9 @@ __global__ void ls_dropout_bias_bwd_kernel( ...@@ -621,11 +613,9 @@ __global__ void ls_dropout_bias_bwd_kernel(
} }
__syncthreads(); __syncthreads();
for (int i = 1; i < WARP_SIZE; i <<= 1) for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
sum += g.shfl_down(sum, i);
if (y == 0) if (y == 0) tile[0][x] = sum;
tile[0][x] = sum;
__syncthreads(); __syncthreads();
if (threadIdx.x < 8) { if (threadIdx.x < 8) {
...@@ -689,8 +679,7 @@ __global__ void ls_dropout_act_bias_kernel( ...@@ -689,8 +679,7 @@ __global__ void ls_dropout_act_bias_kernel(
const float scale = 1.f / (1.f - ratio); const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count) if (i * 4 >= total_count) return;
return;
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state); curand_init(seed, i, 0, &state);
...@@ -735,8 +724,7 @@ __global__ void ls_dropout_act_bias_kernel( ...@@ -735,8 +724,7 @@ __global__ void ls_dropout_act_bias_kernel(
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count) if (i * 8 >= total_count) return;
return;
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state); curand_init(seed, i, 0, &state);
...@@ -897,11 +885,9 @@ __global__ void ls_dropout_act_bias_bwd_kernel( ...@@ -897,11 +885,9 @@ __global__ void ls_dropout_act_bias_bwd_kernel(
float sum = tile[threadIdx.y][threadIdx.x]; float sum = tile[threadIdx.y][threadIdx.x];
__syncthreads(); __syncthreads();
for (int i = 1; i < WARP_SIZE; i <<= 1) for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
sum += g.shfl_down(sum, i);
if (threadIdx.x == 0) if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
tile[0][threadIdx.y] = sum;
__syncthreads(); __syncthreads();
if (threadIdx.y == 0) { if (threadIdx.y == 0) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment