"src/vscode:/vscode.git/clone" did not exist on "909742dbd6873052995dc6cd5f4150ff238015d2"
Unverified commit 3efbdf68, authored by Xiaoyu Zhang and committed by GitHub

fix sgl-kernel codestyle (#3563)

parent 6cc30955
@@ -33,11 +33,11 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
                                                   const int batch_size, const int num_heads, const int qk_dim,
                                                   const int v_dim) {
   extern __shared__ char smem[];
-  T* q_shared = reinterpret_cast<T*>(smem);
-  T* k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
-  T* v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
-  float* new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
-  T* output_shared =
+  T* __restrict__ q_shared = reinterpret_cast<T*>(smem);
+  T* __restrict__ k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
+  T* __restrict__ v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
+  float* __restrict__ new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
+  T* __restrict__ output_shared =
       reinterpret_cast<T*>(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float));
 
   const int32_t tid = threadIdx.x;
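Editor's note on the hunk above: the kernel carves five regions out of one dynamically sized shared-memory buffer, so the launch site must pass a byte count that matches these offsets. A minimal sketch of that arithmetic, derived from the offsets in the diff (the helper name is hypothetical and not part of this commit):

#include <cstddef>

// Hypothetical helper, for illustration only: the dynamic shared-memory size
// implied by the offsets above. Layout: q (qk_dim T) | k (qk_dim T) |
// v (v_dim T) | new_kv (qk_dim * (v_dim + 1) floats, one float of row
// padding) | output (v_dim T).
template <typename T>
size_t lightning_decode_smem_bytes(int qk_dim, int v_dim) {
  return (2 * qk_dim + v_dim) * sizeof(T)      // q_shared, k_shared, v_shared
       + qk_dim * (v_dim + 1) * sizeof(float)  // new_kv_shared (padded rows)
       + v_dim * sizeof(T);                    // output_shared
}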
@@ -51,6 +51,7 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
   const int32_t v_offset = b * num_heads * v_dim + h * v_dim;
   const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim;
 
+  // Load q, k, v into shared memory
   for (int d = tid; d < qk_dim; d += blockDim.x) {
     q_shared[d] = q[qk_offset + d];
     k_shared[d] = k[qk_offset + d];
@@ -63,33 +64,36 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
   const float ratio = expf(-1.0f * slope[h]);
 
+  // Compute new_kv
   for (int d = tid; d < qk_dim; d += blockDim.x) {
-    T k_val = k_shared[d];
+    const T k_val = k_shared[d];
     for (int e = 0; e < v_dim; ++e) {
-      int past_kv_idx = kv_offset + d * v_dim + e;
-      T v_val = v_shared[e];
-      float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
-      int shared_idx = d * (v_dim + 1) + e;
+      const int past_kv_idx = kv_offset + d * v_dim + e;
+      const T v_val = v_shared[e];
+      const float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
+      const int shared_idx = d * (v_dim + 1) + e;
       new_kv_shared[shared_idx] = new_val;
     }
   }
   __syncthreads();
 
+  // Store new_kv to global memory
   for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) {
-    int d = idx / v_dim;
-    int e = idx % v_dim;
-    int shared_idx = d * (v_dim + 1) + e;
-    int global_idx = kv_offset + idx;
+    const int d = idx / v_dim;
+    const int e = idx % v_dim;
+    const int shared_idx = d * (v_dim + 1) + e;
+    const int global_idx = kv_offset + idx;
     new_kv[global_idx] = new_kv_shared[shared_idx];
   }
   __syncthreads();
 
+  // Compute output
   for (int e = tid; e < v_dim; e += blockDim.x) {
     float sum = 0.0f;
     for (int d = 0; d < qk_dim; ++d) {
-      int shared_idx = d * (v_dim + 1) + e;
+      const int shared_idx = d * (v_dim + 1) + e;
       sum += q_shared[d] * new_kv_shared[shared_idx];
     }
     output_shared[e] = static_cast<T>(sum);
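Editor's note: the `d * (v_dim + 1) + e` indexing above pads each row of `new_kv_shared` by one float, the standard trick for staggering shared-memory bank access. An illustrative sketch, not taken from the commit:

// Illustrative sketch of the padded indexing used for new_kv_shared above.
// Shared memory has 32 four-byte banks; with a row stride of exactly v_dim
// (typically a multiple of 32), threads walking a column would all hit the
// same bank. Striding by v_dim + 1 shifts each row over by one bank.
__device__ __forceinline__ int padded_idx(int d, int e, int v_dim) {
  return d * (v_dim + 1) + e;  // one padding element per row
}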
@@ -97,6 +101,7 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
   __syncthreads();
 
+  // Store output to global memory
   if (tid == 0) {
     for (int e = 0; e < v_dim; ++e) {
       output[v_offset + e] = output_shared[e];
@@ -25,8 +25,9 @@ limitations under the License.
 
 #define WARP_SIZE 32
 
 template <typename scalar_t>
-__global__ void moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
-                                      int32_t* cumsum_buffer, size_t numel) {
+__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids,
+                                                    int32_t* __restrict__ sorted_token_ids,
+                                                    int32_t* __restrict__ cumsum_buffer, size_t numel) {
   const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
   const size_t stride = blockDim.x * gridDim.x;
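Editor's note: the `tid`/`stride` pair above is the standard grid-stride loop setup. A self-contained sketch of the pattern with a hypothetical kernel, for illustration only:

#include <cstdint>

// Minimal grid-stride loop sketch (not from this diff): each thread starts
// at its global index and advances by the total number of launched threads,
// so any grid size covers all numel elements exactly once.
__global__ void grid_stride_copy(const int32_t* __restrict__ in, int32_t* __restrict__ out, size_t numel) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (size_t i = tid; i < numel; i += stride) {
    out[i] = in[i];  // each thread handles indices tid, tid + stride, ...
  }
}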
@@ -38,9 +39,10 @@ __global__ void moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, int32_t*
 }
 
 template <typename scalar_t>
-__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
-                                            int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
-                                            int32_t block_size, size_t numel, int32_t* cumsum) {
+__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
+                                            int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
+                                            int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
+                                            int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) {
   __shared__ int32_t shared_counts[WARP_SIZE][8];
 
   const int warp_id = threadIdx.x / WARP_SIZE;
@@ -106,7 +108,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
   const int max_blocks = 65535;
   const int actual_blocks = std::min(num_blocks, max_blocks);
 
-  auto sort_kernel = moe_token_sort_kernel<scalar_t>;
+  auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
   sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(),
                                                            sorted_token_ids.data_ptr<int32_t>(),
                                                            cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
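Editor's note: the 65535-block cap above is safe precisely because the sort kernel loops with a grid stride, so excess elements become extra loop iterations rather than extra blocks. A sketch of the launch sizing under that assumption (the `block_threads` value is an assumption for illustration; the real value comes from surrounding code this diff truncates):

#include <algorithm>
#include <cstddef>

// Hypothetical helper: enough blocks to cover numel once, capped at 65535;
// the kernel's grid-stride loop picks up any remainder.
inline int sort_kernel_blocks(size_t numel, int block_threads = 256) {
  const int num_blocks = static_cast<int>((numel + block_threads - 1) / block_threads);
  return std::min(num_blocks, 65535);
}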
@@ -7,13 +7,11 @@
 
 using FP8_TYPE = c10::Float8_e4m3fn;
 
-__device__ __forceinline__ float WarpReduce(volatile float* smem, const int tid) {
-  if (tid < 8) {
-    smem[tid] = fmaxf(smem[tid], smem[tid + 8]);
-    if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]);
-    if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]);
-    if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]);
-  }
+__device__ __forceinline__ float GroupReduce(volatile float* smem, const int tid) {
+  smem[tid] = fmaxf(smem[tid], smem[tid + 8]);
+  if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]);
+  if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]);
+  if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]);
   return smem[0];
 }
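Editor's note: GroupReduce above is a tree max over 16 shared-memory slots, relying on `volatile` to keep the intra-warp stores visible without explicit syncs (the internal `tid < 8` guard became redundant once the caller guards the call, as the next hunk shows). For comparison, a warp-shuffle variant that avoids shared memory entirely; this is an alternative sketch, not what the commit uses:

// Alternative sketch using warp shuffles (not part of this commit): lanes
// 0..15 each hold one value; after four halving steps lane 0 has the max.
// width = 16 confines each shuffle to the 16-lane group.
__device__ __forceinline__ float GroupMaxShuffle(float v) {
  v = fmaxf(v, __shfl_down_sync(0xffffu, v, 8, 16));
  v = fmaxf(v, __shfl_down_sync(0xffffu, v, 4, 16));
  v = fmaxf(v, __shfl_down_sync(0xffffu, v, 2, 16));
  v = fmaxf(v, __shfl_down_sync(0xffffu, v, 1, 16));
  return v;  // valid in lane 0
}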
@@ -53,7 +51,7 @@ __global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, vo
 
   // Perform reduction within each group
   if (local_tid < 8) {
-    WarpReduce(&s_absmax[local_group_id][0], local_tid);
+    GroupReduce(&s_absmax[local_group_id][0], local_tid);
   }
   __syncthreads();