Commit 71cac971 authored by sunchao_0511's avatar sunchao_0511
Browse files

remove rope debug && swiglu optim

parent a1937618
......@@ -126,13 +126,6 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
// 3D tensors: use 2D grid [seqlen, nhead], batch dimension is 1
grid_dim = dim3(dimx, dimy, 1);
}
// printf("block_size = %d info.table_dim = %ld has_batch_dim: %d, is_gpt_j: %d pos_has_batch_dim: %d\n",
// block_size, info.table_dim, info.has_batch_dim, is_gpt_j, info.pos_has_batch_dim);
// [batch, seqlen, nhead, dhead, table_len, table_dim, y_stride_batch, y_stride_seqlen, y_stride_nhead, x_stride_batch, x_stride_seqlen,x_stride_nhead]
// printf("[%ld %ld %ld %ld %ld %ld %ld %ld %ld %ld %ld %ld]\n", info.batch,
// info.seqlen, info.nhead, info.dhead, info.table_len, info.table_dim,
// info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
// info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
if (is_gpt_j) {
ropeThreadPerItemKernel<true><<<grid_dim, nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
......@@ -154,14 +147,6 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
info.has_batch_dim,
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
// ropeThreadPerItemKernel<false><<<grid_dim, nthreads, 0, stream>>>(
// y, x, pos_ids, sin_table, cos_table, info.table_dim,
// pos_stride_batch,
// info.pos_has_batch_dim,
// info.has_batch_dim,
// info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
// info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
} else {
ropeThreadPerItemKernel<false><<<grid_dim, nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
......
#ifndef __SWIGLU_CUDA_H__
#define __SWIGLU_CUDA_H__
namespace op::swiglu::cuda {
typedef struct SwiGLUOp {
private:
template <typename T>
__device__ __forceinline__ T sigmoid(const T &x) const {
if constexpr (std::is_same_v<T, half2>) {
return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(x))));
} else if constexpr (std::is_same_v<T, half>) {
return hrcp(__hadd(half(1.f), __float2half(__expf(__half2float(__hneg(x))))));
} else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
float x0 = __bfloat162float(__low2bfloat16(x));
float x1 = __bfloat162float(__high2bfloat16(x));
float sig0 = __frcp_rn(__fadd_rn(1.0f, __expf(-x0)));
float sig1 = __frcp_rn(__fadd_rn(1.0f, __expf(-x1)));
return __floats2bfloat162_rn(sig0, sig1);
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
float xf = __bfloat162float(x);
return __float2bfloat16_rn(__frcp_rn(__fadd_rn(1.0f, __expf(-xf))));
} else if constexpr (std::is_same_v<T, float>) {
return __frcp_rn(__fadd_rn(1, __expf(-x)));
} else {
return 1 / (1 + std::exp(-x));
}
#ifndef __SWIGLU_CUDA_KERNEL_CUH__
#define __SWIGLU_CUDA_KERNEL_CUH__
// Element-wise logistic sigmoid, 1 / (1 + exp(-x)), specialized per type.
//
// Packed types (half2, cuda_bfloat162) process two lanes at once. The
// bf16 variants unpack to float, compute there, and round back with
// *_rn, since bf16 has no native rcp/exp intrinsics. All GPU paths use
// the fast approximations (__expf / h2exp), so results are not bit-exact
// with std::exp; the generic fallback (e.g. double) uses precise math.
template <typename T>
__device__ __forceinline__ T sigmoid(const T &x) {
    if constexpr (std::is_same_v<T, half2>) {
        // Both half lanes in one packed op: 1 / (1 + exp(-x)).
        return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(x))));
    } else if constexpr (std::is_same_v<T, half>) {
        // exp computed in float for accuracy, then narrowed back to half.
        return hrcp(__hadd(half(1.f), __float2half(__expf(__half2float(__hneg(x))))));
    } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
        // Unpack the two bf16 lanes, compute in float, repack.
        float x0 = __bfloat162float(__low2bfloat16(x));
        float x1 = __bfloat162float(__high2bfloat16(x));
        float sig0 = __frcp_rn(__fadd_rn(1.0f, __expf(-x0)));
        float sig1 = __frcp_rn(__fadd_rn(1.0f, __expf(-x1)));
        return __floats2bfloat162_rn(sig0, sig1);
    } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
        float xf = __bfloat162float(x);
        return __float2bfloat16_rn(__frcp_rn(__fadd_rn(1.0f, __expf(-xf))));
    } else if constexpr (std::is_same_v<T, float>) {
        // 1.0f (was a bare int 1) keeps the literal in float from the start.
        return __frcp_rn(__fadd_rn(1.0f, __expf(-x)));
    } else {
        // Generic fallback for remaining arithmetic types (e.g. double).
        return 1 / (1 + std::exp(-x));
    }
}
template <typename T, unsigned int BLOCK_SIZE>
__device__ void SwiGLUCudaKernel(
T *c,
const T *a,
const T *b,
int length,
size_t batch, size_t seq_len, size_t hidden_dim,
ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_2,
ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_2,
ptrdiff_t b_strides_0, ptrdiff_t b_strides_1, ptrdiff_t b_strides_2) {
int ind_c = 0;
int ind_a = 0;
int ind_b = 0;
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < length) {
ind_c += tid % (int)hidden_dim * (int)c_strides_2;
ind_a += tid % (int)hidden_dim * (int)a_strides_2;
ind_b += tid % (int)hidden_dim * (int)b_strides_2;
tid = tid / (int)hidden_dim;
ind_c += (tid % (int)seq_len) * (int)c_strides_1;
ind_a += (tid % (int)seq_len) * (int)a_strides_1;
ind_b += (tid % (int)seq_len) * (int)b_strides_1;
tid = tid / (int)seq_len;
ind_c += (tid % (int)batch) * (int)c_strides_0;
ind_a += (tid % (int)batch) * (int)a_strides_0;
ind_b += (tid % (int)batch) * (int)b_strides_0;
T gate = b[ind_b];
T up = a[ind_a];
public:
static constexpr size_t num_inputs = 2;
template <typename T>
__device__ __forceinline__ T operator()(const T &up, const T &gate) const {
if constexpr (std::is_same_v<T, half2>) {
return __hmul2(__hmul2(gate, sigmoid(gate)), up);
c[ind_c] = __hmul2(__hmul2(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, half>) {
return __hmul(__hmul(gate, sigmoid(gate)), up);
c[ind_c] = __hmul(__hmul(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
cuda_bfloat162 sig = sigmoid(gate);
float gate0 = __bfloat162float(__low2bfloat16(gate));
......@@ -44,20 +66,96 @@ public:
float up1 = __bfloat162float(__high2bfloat16(up));
float res0 = __fmul_rn(__fmul_rn(gate0, sig0), up0);
float res1 = __fmul_rn(__fmul_rn(gate1, sig1), up1);
return __floats2bfloat162_rn(res0, res1);
c[ind_c] = __floats2bfloat162_rn(res0, res1);
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
cuda_bfloat16 sig = sigmoid(gate);
float gatef = __bfloat162float(gate);
float sigf = __bfloat162float(sig);
float upf = __bfloat162float(up);
return __float2bfloat16_rn(__fmul_rn(__fmul_rn(gatef, sigf), upf));
c[ind_c] = __float2bfloat16_rn(__fmul_rn(__fmul_rn(gatef, sigf), upf));
} else if constexpr (std::is_same_v<T, float>) {
return __fmul_rn(__fmul_rn(gate, sigmoid(gate)), up);
c[ind_c] = __fmul_rn(__fmul_rn(gate, sigmoid(gate)), up);
} else {
return gate * sigmoid(gate) * up;
c[ind_c] = gate * sigmoid(gate) * up;
}
}
} SwiGLUOp;
} // namespace op::swiglu::cuda
}
// Scalar bf16 SwiGLU: c[i] = b[i] * sigmoid(b[i]) * a[i] (gate = b, up = a).
// One thread per output element; the launch must cover `length` threads.
//
// Fast path with hard-coded layout assumptions:
//   - the output c is contiguous (ind_c = tid), so c_strides_* are ignored;
//   - a shares b's strides (ind_a = ind_b), so a_strides_* are ignored.
// NOTE(review): those parameters are accepted only for signature parity
// with SwiGLUCudaKernel — confirm against the bf16 dispatch in
// calculate_swiglu_cuda before using this kernel with other layouts.
__device__ void CustomSwiGLUCudaKernel(
    __nv_bfloat16 *c,
    const __nv_bfloat16 *a,
    const __nv_bfloat16 *b,
    int length,
    size_t batch, size_t seq_len, size_t hidden_dim,
    ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_2,
    ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_2,
    ptrdiff_t b_strides_0, ptrdiff_t b_strides_1, ptrdiff_t b_strides_2) {
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid >= length) return;
    // Decompose the flat thread index into (batch, seq, hidden) coordinates.
    int batchIdx = tid / (seq_len * hidden_dim);
    int seqIdx = (tid - batchIdx * seq_len * hidden_dim) / hidden_dim;
    int hiddenIdx = tid - (batchIdx * seq_len * hidden_dim + seqIdx * hidden_dim);
    int ind_c = tid;  // output assumed contiguous
    int ind_b = batchIdx * b_strides_0 + seqIdx * b_strides_1 + hiddenIdx * b_strides_2;
    int ind_a = ind_b;  // a assumed to share b's layout
    __nv_bfloat16 gate = b[ind_b];
    __nv_bfloat16 up = a[ind_a];
    // Compute in float, but round the sigmoid through bf16 so the result
    // matches the generic SwiGLUOp bf16 path bit-for-bit.
    float gatef = __bfloat162float(gate);
    __nv_bfloat16 sig = __float2bfloat16_rn(__frcp_rn(__fadd_rn(1.0f, __expf(-gatef))));
    float sigf = __bfloat162float(sig);
    float upf = __bfloat162float(up);
    c[ind_c] = __float2bfloat16_rn(__fmul_rn(__fmul_rn(gatef, sigf), upf));
}
// Vectorized bf16 SwiGLU: each thread handles 8 bf16 elements with a
// single 16-byte float4 load per input tensor and one float4 store.
//
// Preconditions (established by the CustomVecSwiGLUCuda wrapper):
//   - length, hidden_dim and the outer strides (*_strides_0/1) arrive
//     pre-divided by 8, so `tid` indexes 8-element groups and `<< 3`
//     recovers bf16-element offsets;
//   - the innermost dimension is contiguous and all three base pointers
//     are 16-byte aligned, otherwise the reinterpret_cast<float4*>
//     accesses are invalid.
// NOTE(review): as in the scalar kernel, c is assumed contiguous
// (ind_c = tid << 3) and a is assumed to share b's strides; the
// c_strides_*/a_strides_* parameters are ignored — confirm with callers.
__device__ void CustomVecSwiGLUCudaKernel(
    __nv_bfloat16 *c,
    const __nv_bfloat16 *a,
    const __nv_bfloat16 *b,
    int length,
    size_t batch, size_t seq_len, size_t hidden_dim,
    ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_2,
    ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_2,
    ptrdiff_t b_strides_0, ptrdiff_t b_strides_1, ptrdiff_t b_strides_2) {
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid >= length) return;
    // Decompose the flat group index into (batch, seq, hidden) coordinates.
    int batchIdx = tid / (seq_len * hidden_dim);
    int seqIdx = (tid - batchIdx * seq_len * hidden_dim) / hidden_dim;
    int hiddenIdx = tid - (batchIdx * seq_len * hidden_dim + seqIdx * hidden_dim);
    // << 3 converts the 8-element group index back to a bf16 offset.
    int ind_c = tid << 3;
    int ind_b = (batchIdx * b_strides_0 + seqIdx * b_strides_1 + hiddenIdx * b_strides_2) << 3;
    int ind_a = ind_b;
    __nv_bfloat16 gate[8];
    __nv_bfloat16 up[8];
    __nv_bfloat16 output[8];
    // One 128-bit transaction per tensor instead of eight 16-bit ones.
    const float4* global_gate = reinterpret_cast<const float4*>(b + ind_b);
    const float4* global_up = reinterpret_cast<const float4*>(a + ind_a);
    float4* global_output = reinterpret_cast<float4*>(c + ind_c);
    float4 gate_val = *global_gate;
    float4 up_val = *global_up;
    *reinterpret_cast<float4*>(gate) = gate_val;
    *reinterpret_cast<float4*>(up) = up_val;
#pragma unroll
    for (int i = 0; i < 8; i++) {
        // Same math as the scalar kernel: sigmoid rounded through bf16
        // so results match the non-vectorized path bit-for-bit.
        float gatef = __bfloat162float(gate[i]);
        __nv_bfloat16 sig = __float2bfloat16_rn(__frcp_rn(__fadd_rn(1.0f, __expf(-gatef))));
        float sigf = __bfloat162float(sig);
        float upf = __bfloat162float(up[i]);
        output[i] = __float2bfloat16_rn(__fmul_rn(__fmul_rn(gatef, sigf), upf));
    }
    *global_output = *reinterpret_cast<float4*>(output);
}
#endif // __SWIGLU_CUDA_H__
#endif // __SWIGLU_CUDA_KERNEL_CUH__
......@@ -19,6 +19,45 @@ INFINIOP_CUDA_KERNEL SwiGLUCuda(
b_strides_0, b_strides_1, b_strides_2);
}
// Kernel entry point for the scalar bf16 SwiGLU fast path: forwards all
// arguments unchanged to CustomSwiGLUCudaKernel (one thread per output
// element; the launch configuration must cover `length` threads).
// NOTE(review): INFINIOP_CUDA_KERNEL presumably expands to
// `__global__ void` — confirm against the project macro definition.
INFINIOP_CUDA_KERNEL CustomSwiGLUCuda(
    __nv_bfloat16 *c,
    const __nv_bfloat16 *a,
    const __nv_bfloat16 *b,
    int length,
    size_t batch, size_t seq_len, size_t hidden_dim,
    ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_2,
    ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_2,
    ptrdiff_t b_strides_0, ptrdiff_t b_strides_1, ptrdiff_t b_strides_2) {
    CustomSwiGLUCudaKernel(c, a, b, length, batch, seq_len, hidden_dim,
                           c_strides_0, c_strides_1, c_strides_2,
                           a_strides_0, a_strides_1, a_strides_2,
                           b_strides_0, b_strides_1, b_strides_2);
}
// Kernel entry point for the vectorized bf16 SwiGLU path. Rescales the
// flat element count, the hidden dimension, and the outer (batch/seq)
// strides from bf16-element units into 8-element float4 units before
// delegating to CustomVecSwiGLUCudaKernel, which moves one float4
// (8 x bf16) per thread. The innermost strides (*_strides_2) pass
// through unscaled. The caller guarantees hidden_dim % 8 == 0.
INFINIOP_CUDA_KERNEL CustomVecSwiGLUCuda(
    __nv_bfloat16 *c,
    const __nv_bfloat16 *a,
    const __nv_bfloat16 *b,
    int length,
    size_t batch, size_t seq_len, size_t hidden_dim,
    ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_2,
    ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_2,
    ptrdiff_t b_strides_0, ptrdiff_t b_strides_1, ptrdiff_t b_strides_2) {
    constexpr int kVecWidth = 8;  // bf16 elements packed into one float4
    CustomVecSwiGLUCudaKernel(
        c, a, b,
        length / kVecWidth,
        batch, seq_len, hidden_dim / kVecWidth,
        c_strides_0 / kVecWidth, c_strides_1 / kVecWidth, c_strides_2,
        a_strides_0 / kVecWidth, a_strides_1 / kVecWidth, a_strides_2,
        b_strides_0 / kVecWidth, b_strides_1 / kVecWidth, b_strides_2);
}
namespace op::swiglu_cuda::nvidia {
struct Descriptor::Opaque {
......@@ -68,13 +107,39 @@ infiniStatus_t calculate_swiglu_cuda(
ptrdiff_t b_strides_1 = info.b_strides_1;
ptrdiff_t b_strides_2 = info.b_strides_2;
int num_blocks = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
SwiGLUCuda<T, BLOCK_SIZE>
bool vec_flag = false;
//向量化取数据在这个长度下性能才最优
if ((hidden_dim % 8 == 0) && length >= 295680) {
vec_flag = true;
}
if (std::is_same<T, __nv_bfloat16>::value) {
auto bf16_c = reinterpret_cast<__nv_bfloat16*>(c);
auto bf16_a = reinterpret_cast<const __nv_bfloat16*>(a);
auto bf16_b = reinterpret_cast<const __nv_bfloat16*>(b);
if(vec_flag) {
int block_size = 256;
int grid_size = (length / 8 + block_size - 1) / block_size;
CustomVecSwiGLUCuda<<<grid_size, block_size, 0, stream>>>(bf16_c, bf16_a, bf16_b, length, batch, seq_len, hidden_dim,
c_strides_0, c_strides_1, c_strides_2,
a_strides_0, a_strides_1, a_strides_2,
b_strides_0, b_strides_1, b_strides_2);
} else {
int block_size = 256;
int grid_size = (length + block_size - 1) / block_size;
CustomSwiGLUCuda<<<grid_size, block_size, 0, stream>>>(bf16_c, bf16_a, bf16_b, length, batch, seq_len, hidden_dim,
c_strides_0, c_strides_1, c_strides_2,
a_strides_0, a_strides_1, a_strides_2,
b_strides_0, b_strides_1, b_strides_2);
}
} else {
int num_blocks = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
SwiGLUCuda<T, BLOCK_SIZE>
<<<num_blocks, BLOCK_SIZE, 0, stream>>>(c, a, b, length, batch, seq_len, hidden_dim,
c_strides_0, c_strides_1, c_strides_2,
a_strides_0, a_strides_1, a_strides_2,
b_strides_0, b_strides_1, b_strides_2);
}
return INFINI_STATUS_SUCCESS;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment