Unverified Commit ff00895c authored by jianan-gu, committed by GitHub

Add CPU-optimized kernels for topk and rope fusions (#6456)

parent ff914748
......@@ -4,6 +4,67 @@
namespace {
// NB: avoid using `at::vec::map<>` on bfloat16 or half
// Llama4TextL2Norm
template <typename scalar_t>
void l2norm_kernel_impl(
scalar_t* __restrict__ output,
const scalar_t* __restrict__ input,
int64_t batch_size,
int64_t hidden_size,
float eps = 1e-5) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
// local ptrs
scalar_t* __restrict__ out_ptr = output + i * hidden_size;
const scalar_t* __restrict__ input_ptr = input + i * hidden_size;
fVec sum_fvec = fVec(float(0));
float sum_val = float(0);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
bVec x_bvec = bVec::loadu(input_ptr + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
sum_fvec += x_fvec0 * x_fvec0;
sum_fvec += x_fvec1 * x_fvec1;
}
#pragma GCC unroll 4
for (; d < hidden_size; ++d) {
float x_val = static_cast<float>(input_ptr[d]);
sum_val += x_val * x_val;
}
sum_val += vec_reduce_sum(sum_fvec);
float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
const fVec scale_fvec = fVec(rsqrt_var);
#pragma GCC unroll 4
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
bVec x_bvec = bVec::loadu(input_ptr + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
x_fvec0 = x_fvec0 * scale_fvec;
x_fvec1 = x_fvec1 * scale_fvec;
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
out_bvec.store(out_ptr + d);
}
#pragma GCC unroll 4
for (; d < hidden_size; ++d) {
float x_val = static_cast<float>(input_ptr[d]);
out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var);
}
}
});
}
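For reference, the vectorized loops above compute the following per row of the {batch_size, hidden_size} input, with the sum of squares accumulated in fp32 before the result is cast back to scalar_t:
\[
\mathrm{out}_{i,d} \;=\; \frac{x_{i,d}}{\sqrt{\frac{1}{H}\sum_{j=0}^{H-1} x_{i,j}^{2} \;+\; \varepsilon}},
\qquad H = \texttt{hidden\_size},\ \varepsilon = \texttt{eps}.
\]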
template <typename scalar_t>
void rmsnorm_kernel_impl(
scalar_t* __restrict__ output,
......@@ -160,6 +221,22 @@ void fused_add_rmsnorm_kernel_impl(
} // anonymous namespace
// input : {batch_size, hidden_size}
at::Tensor l2norm_cpu(at::Tensor& input, double eps) {
RECORD_FUNCTION("sgl-kernel::l2norm_cpu", std::vector<c10::IValue>({input}));
CHECK_INPUT(input);
CHECK_DIM(2, input);
int64_t batch_size = input.size(0);
int64_t hidden_size = input.size(1);
at::Tensor output = at::empty_like(input);
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "l2norm_kernel", [&] {
l2norm_kernel_impl<scalar_t>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), batch_size, hidden_size, eps);
});
return output;
}
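A minimal usage sketch of the new op from Python, mirroring the l2norm test added later in this diff; it assumes the sgl-kernel CPU extension is built and importable, and the tolerances are illustrative only:
import torch
import sgl_kernel  # noqa: F401  (registers torch.ops.sgl_kernel.* ops)

x = torch.randn(4, 128, dtype=torch.bfloat16)
eps = 1e-6
out = torch.ops.sgl_kernel.l2norm_cpu(x, eps)
# fp32 reference: x * rsqrt(mean(x^2, dim=-1) + eps), cast back to the input dtype
ref = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)).to(x.dtype)
torch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2)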
// input : {batch_size, hidden_size}
// weight: {hidden_size}
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
......
......@@ -4,126 +4,343 @@
namespace {
template <typename scalar_t>
void rope_kernel_impl(
scalar_t* __restrict__ q_pe_out,
scalar_t* __restrict__ k_pe_out,
int64_t* __restrict__ t_pos,
scalar_t* __restrict__ q_pe,
scalar_t* __restrict__ k_pe,
scalar_t* __restrict__ t_emb_pos,
int64_t seq_len,
int64_t num_head,
void rotary_embedding_3D_kernel_impl(
scalar_t* __restrict__ query_out,
scalar_t* __restrict__ key_out,
int64_t* __restrict__ positions,
scalar_t* __restrict__ query,
scalar_t* __restrict__ key,
scalar_t* __restrict__ cos_sin_cache,
int64_t num_tokens,
int64_t num_heads,
int64_t num_kv_heads,
int64_t head_size,
int64_t rotary_dim,
int64_t HR,
int64_t q_pe_stride_s,
int64_t out_stride_qs,
int64_t out_stride_ks,
int64_t HK,
int64_t k_pe_stride_s,
int64_t q_pe_stride_n,
int64_t out_stride_qn) {
int64_t query_stride_s,
int64_t query_out_stride_s,
int64_t key_out_stride_s,
int64_t key_stride_s,
int64_t query_stride_h,
int64_t query_out_stride_h) {
int64_t HR = rotary_dim;
int64_t HK = rotary_dim;
int64_t COFF = HR / 2;
at::parallel_for(0, seq_len * num_head, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
at::parallel_for(0, num_tokens * num_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
int64_t seq{0}, head_id{0};
data_index_init(begin, seq, seq_len, head_id, num_head);
data_index_init(begin, seq, num_tokens, head_id, num_heads);
for (int64_t i = begin; i < end; ++i) {
int64_t in_offset_q = seq * q_pe_stride_s + head_id * q_pe_stride_n;
int64_t out_offset_q = seq * out_stride_qs + head_id * out_stride_qn;
int64_t out_offset_k = seq * out_stride_ks;
int64_t in_offset_q = seq * query_stride_s + head_id * query_stride_h;
int64_t out_offset_q = seq * query_out_stride_s + head_id * query_out_stride_h;
int64_t out_offset_k = seq * key_out_stride_s;
int64_t p = 0;
scalar_t* sin_start = nullptr;
scalar_t* cos_start = nullptr;
// step 0) get the rotary position embedding for the current position
p = t_pos[seq];
sin_start = t_emb_pos + p * HR + COFF;
cos_start = t_emb_pos + p * HR;
p = positions[seq];
sin_start = cos_sin_cache + p * HR + COFF;
cos_start = cos_sin_cache + p * HR;
// step 1) apply_rotary_pos_emb for the rotary_dim elements in every
// head of query/key
for (int64_t h = 0; h < rotary_dim; h += 2) {
scalar_t cos = cos_start[h >> 1];
scalar_t sin = sin_start[h >> 1];
scalar_t in1 = q_pe[in_offset_q + h];
scalar_t in2 = q_pe[in_offset_q + h + 1];
scalar_t in1 = query[in_offset_q + h];
scalar_t in2 = query[in_offset_q + h + 1];
scalar_t out1 = in1 * cos - in2 * sin;
scalar_t out2 = in2 * cos + in1 * sin;
q_pe_out[out_offset_q + h] = out1;
q_pe_out[out_offset_q + h + 1] = out2;
query_out[out_offset_q + h] = out1;
query_out[out_offset_q + h + 1] = out2;
}
for (int64_t h = 0; h < HK; h += 2) {
scalar_t cos = cos_start[h >> 1];
scalar_t sin = sin_start[h >> 1];
int64_t k_pe_offset = seq * k_pe_stride_s;
scalar_t in1_k = k_pe[k_pe_offset + h];
scalar_t in2_k = k_pe[k_pe_offset + h + 1];
int64_t k_pe_offset = seq * key_stride_s;
scalar_t in1_k = key[k_pe_offset + h];
scalar_t in2_k = key[k_pe_offset + h + 1];
scalar_t out1_k = in1_k * cos - in2_k * sin;
scalar_t out2_k = in2_k * cos + in1_k * sin;
k_pe_out[out_offset_k + h] = out1_k;
k_pe_out[out_offset_k + h + 1] = out2_k;
key_out[out_offset_k + h] = out1_k;
key_out[out_offset_k + h + 1] = out2_k;
}
// move to the next index
data_index_step(seq, seq_len, head_id, num_head);
data_index_step(seq, num_tokens, head_id, num_heads);
}
});
}
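In math terms, for a token at position p = positions[seq] the kernel above applies the interleaved rotation to the first rotary_dim elements of each query head (and of the single key head), reading cos and sin from the two halves of row p of cos_sin_cache:
\[
\begin{aligned}
\cos_j &= \mathrm{cache}[p,\,j], \qquad \sin_j = \mathrm{cache}[p,\,r/2 + j],\\
\mathrm{out}_{2j} &= x_{2j}\cos_j - x_{2j+1}\sin_j, \qquad
\mathrm{out}_{2j+1} = x_{2j+1}\cos_j + x_{2j}\sin_j,
\end{aligned}
\qquad 0 \le j < r/2,\ r = \texttt{rotary\_dim}.
\]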
template <typename scalar_t>
void rotary_embedding_neox_2D_kernel_impl(
int64_t* __restrict__ positions,
scalar_t* __restrict__ query,
scalar_t* __restrict__ key,
scalar_t* __restrict__ cos_sin_cache,
int64_t rotary_dim,
int64_t query_stride_s,
int64_t key_stride_s,
int64_t num_heads,
int64_t num_kv_heads,
int64_t head_size,
int64_t num_tokens) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int64_t bVecSize = bVec::size();
int64_t embed_dim = rotary_dim / 2;
bool flag = (embed_dim % bVecSize == 0);
int64_t loop_upper = flag ? embed_dim : embed_dim - bVecSize;
auto compute_loop = [&](int64_t token_head, scalar_t* cache_ptr, scalar_t* qk) {
int64_t j = 0;
for (; j < loop_upper; j += bVecSize) {
int64_t rot_offset = j;
int64_t x_index = rot_offset;
int64_t y_index = embed_dim + rot_offset;
int64_t out_x = token_head + x_index;
int64_t out_y = token_head + y_index;
bVec _cos = bVec::loadu(cache_ptr + x_index);
bVec _sin = bVec::loadu(cache_ptr + y_index);
bVec _q_x = bVec::loadu(qk + out_x);
bVec _q_y = bVec::loadu(qk + out_y);
fVec _cos_0, _cos_1;
std::tie(_cos_0, _cos_1) = at::vec::convert_to_float(_cos);
fVec _sin_0, _sin_1;
std::tie(_sin_0, _sin_1) = at::vec::convert_to_float(_sin);
fVec _q_x_0, _q_x_1;
std::tie(_q_x_0, _q_x_1) = at::vec::convert_to_float(_q_x);
fVec _q_y_0, _q_y_1;
std::tie(_q_y_0, _q_y_1) = at::vec::convert_to_float(_q_y);
auto out1_0 = _q_x_0 * _cos_0 - _q_y_0 * _sin_0;
auto out1_1 = _q_x_1 * _cos_1 - _q_y_1 * _sin_1;
auto out1 = convert_from_float_ext<scalar_t>(out1_0, out1_1);
out1.store(qk + out_x);
auto out2_0 = _q_y_0 * _cos_0 + _q_x_0 * _sin_0;
auto out2_1 = _q_y_1 * _cos_1 + _q_x_1 * _sin_1;
auto out2 = convert_from_float_ext<scalar_t>(out2_0, out2_1);
out2.store(qk + out_y);
}
if (!flag) {
for (; j < embed_dim; ++j) {
int64_t x_index = j;
int64_t y_index = embed_dim + j;
int64_t out_x = token_head + x_index;
int64_t out_y = token_head + y_index;
float _cos = cache_ptr[x_index];
float _sin = cache_ptr[y_index];
float _q_x = qk[out_x];
float _q_y = qk[out_y];
qk[out_x] = _q_x * _cos - _q_y * _sin;
qk[out_y] = _q_y * _cos + _q_x * _sin;
}
}
};
#pragma omp parallel for
for (int64_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
int64_t pos = positions[token_idx];
scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
for (int64_t i = 0; i < num_heads; ++i) {
int64_t head_idx = i;
int64_t token_head = token_idx * query_stride_s + head_idx * head_size;
compute_loop(token_head, cache_ptr, query);
}
for (int64_t i = 0; i < num_kv_heads; ++i) {
int64_t head_idx = i;
int64_t token_head = token_idx * key_stride_s + head_idx * head_size;
compute_loop(token_head, cache_ptr, key);
}
}
}
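The NeoX variant above instead pairs element j with element j + rotary_dim/2 of the same head and rotates query/key in place, using the same [cos | sin] cache row layout:
\[
x'_{j} = x_{j}\cos_j - x_{j+d/2}\sin_j, \qquad
x'_{j+d/2} = x_{j+d/2}\cos_j + x_{j}\sin_j, \qquad 0 \le j < d/2,\ d = \texttt{rotary\_dim}.
\]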
template <typename scalar_t>
void rotary_embedding_2D_kernel_impl(
int64_t* __restrict__ positions,
scalar_t* __restrict__ query,
scalar_t* __restrict__ key,
scalar_t* __restrict__ cos_sin_cache,
int64_t rotary_dim,
int64_t query_stride_s,
int64_t key_stride_s,
int64_t num_heads,
int64_t num_kv_heads,
int64_t head_size,
int64_t num_tokens) {
int64_t embed_dim = rotary_dim / 2;
at::parallel_for(0, num_tokens * num_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
int64_t token_idx = {0}, i = {0};
data_index_init(begin, token_idx, num_tokens, i, num_heads);
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
int64_t pos = positions[token_idx];
scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
scalar_t* cos_cache_ptr = cache_ptr;
scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
int64_t head_idx = i;
int64_t token_head = token_idx * query_stride_s + head_idx * head_size;
scalar_t* head_query = token_head + query;
for (int64_t j = 0; j < embed_dim; j += 1) {
int64_t rot_offset = j;
int64_t x_index = 2 * rot_offset;
int64_t y_index = 2 * rot_offset + 1;
float cos = cos_cache_ptr[rot_offset];
float sin = sin_cache_ptr[rot_offset];
float x = head_query[x_index];
float y = head_query[y_index];
head_query[x_index] = x * cos - y * sin;
head_query[y_index] = y * cos + x * sin;
}
data_index_step(token_idx, num_tokens, i, num_heads);
}
});
at::parallel_for(0, num_tokens * num_kv_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
int64_t token_idx{0}, i = {0};
data_index_init(begin, token_idx, num_tokens, i, num_kv_heads);
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
int64_t pos = positions[token_idx];
scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
scalar_t* cos_cache_ptr = cache_ptr;
scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
int64_t head_idx = i;
int64_t token_head = token_idx * key_stride_s + head_idx * head_size;
scalar_t* head_key = key + token_head;
for (int64_t j = 0; j < embed_dim; j += 1) {
int64_t rot_offset = j;
int64_t x_index = 2 * rot_offset;
int64_t y_index = 2 * rot_offset + 1;
float cos = cos_cache_ptr[rot_offset];
float sin = sin_cache_ptr[rot_offset];
float x = head_key[x_index];
float y = head_key[y_index];
head_key[x_index] = x * cos - y * sin;
head_key[y_index] = y * cos + x * sin;
}
data_index_step(token_idx, num_tokens, i, num_kv_heads);
}
});
}
} // namespace
std::tuple<at::Tensor, at::Tensor>
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos) {
RECORD_FUNCTION(
"sgl-kernel::rotary_position_embedding_cpu", std::vector<c10::IValue>({t_pos, q_pe, k_pe, t_emb_pos}));
CHECK_INPUT(t_pos);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_pe);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_pe);
CHECK_INPUT(t_emb_pos);
CHECK_DIM(1, t_pos);
CHECK_DIM(3, q_pe);
CHECK_DIM(3, k_pe);
CHECK_DIM(2, t_emb_pos);
int64_t seq_len = q_pe.size(0);
int64_t num_head = q_pe.size(1);
int64_t rotary_dim = q_pe.size(2);
int64_t HK = k_pe.size(2);
int64_t HR = t_emb_pos.size(1);
CHECK_EQ(HR, rotary_dim);
CHECK_EQ(k_pe.size(0), seq_len);
CHECK_EQ(k_pe.size(1), 1);
CHECK_EQ(t_pos.size(0), seq_len);
CHECK_EQ(HK, rotary_dim);
at::Tensor q_pe_out = at::empty_like(q_pe);
at::Tensor k_pe_out = at::empty_like(k_pe);
int64_t q_pe_stride_s = q_pe.stride(0);
int64_t q_pe_stride_n = q_pe.stride(1);
int64_t k_pe_stride_s = k_pe.stride(0);
int64_t out_stride_qs = q_pe_out.stride(0);
int64_t out_stride_qn = q_pe_out.stride(1);
int64_t out_stride_ks = k_pe_out.stride(0);
const auto input_dtype = q_pe.scalar_type();
TORCH_CHECK(t_pos.scalar_type() == at::kLong, "expect positions to be int64, got ", t_pos.scalar_type());
TORCH_CHECK(input_dtype == k_pe.scalar_type(), "q_pe and k_pe must have the same data type");
TORCH_CHECK(input_dtype == t_emb_pos.scalar_type(), "q_pe and t_emb_pos must have the same data type");
AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_position_embedding_cpu", [&] {
rope_kernel_impl<scalar_t>(
q_pe_out.data_ptr<scalar_t>(),
k_pe_out.data_ptr<scalar_t>(),
t_pos.data_ptr<int64_t>(),
q_pe.data_ptr<scalar_t>(),
k_pe.data_ptr<scalar_t>(),
t_emb_pos.data_ptr<scalar_t>(),
seq_len,
num_head,
rotary_dim,
HR,
q_pe_stride_s,
out_stride_qs,
out_stride_ks,
HK,
k_pe_stride_s,
q_pe_stride_n,
out_stride_qn);
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
at::Tensor& positions,
at::Tensor& query,
at::Tensor& key,
int64_t head_size,
at::Tensor& cos_sin_cache,
bool is_neox) {
RECORD_FUNCTION("sgl-kernel::rotary_embedding_cpu", std::vector<c10::IValue>({query, key}));
CHECK_DIM(1, positions);
const auto input_dim = query.dim();
const auto input_dtype = query.scalar_type();
TORCH_CHECK(
input_dim == 2 || input_dim == 3,
" Query/Key must be 2D [num_tokens, num_heads*head_size] or 3D [num_tokens, num_heads, head_size] tensor");
CHECK_DIM(2, cos_sin_cache);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(query);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(key);
int64_t rotary_dim = cos_sin_cache.size(1);
if (input_dim == 3) {
// TODO: add support for head_dim != rotary_dim case when input_dim=3
CHECK_EQ(query.size(-1), rotary_dim);
// TODO: add support for kv_head != 1
CHECK_EQ(key.size(1), 1);
}
int64_t num_tokens = positions.numel();
CHECK_EQ(key.size(0), num_tokens);
CHECK_EQ(query.size(0), num_tokens);
TORCH_CHECK(positions.scalar_type() == at::kLong, "expect positions to be int64, got ", positions.scalar_type());
TORCH_CHECK(input_dtype == key.scalar_type(), "query and key must have the same data type");
TORCH_CHECK(input_dtype == cos_sin_cache.scalar_type(), "query and cos_sin_cache must have the same data type");
int64_t num_heads = input_dim == 2 ? query.size(-1) / head_size : query.size(1);
int64_t num_kv_heads = input_dim == 2 ? key.size(-1) / head_size : key.size(1);
int64_t key_stride_s = key.stride(0);
int64_t query_stride_s = query.stride(0);
// the input's stride along the head dimension is meaningful only when input_dim == 3
int64_t query_stride_h = input_dim == 3 ? query.stride(1) : -1;
at::Tensor query_out = at::empty_like(query);
at::Tensor key_out = at::empty_like(key);
int64_t query_out_stride_s = query_out.stride(0);
int64_t key_out_stride_s = key_out.stride(0);
// the output's stride along the head dimension is meaningful only when input_dim == 3
int64_t query_out_stride_h = input_dim == 3 ? query_out.stride(1) : -1;
AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_embedding_cpu", [&] {
if (input_dim == 2) {
if (is_neox) {
rotary_embedding_neox_2D_kernel_impl<scalar_t>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rotary_dim,
query_stride_s,
key_stride_s,
num_heads,
num_kv_heads,
head_size,
num_tokens);
} else {
rotary_embedding_2D_kernel_impl<scalar_t>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rotary_dim,
query_stride_s,
key_stride_s,
num_heads,
num_kv_heads,
head_size,
num_tokens);
}
query_out = query;
key_out = key;
} else {
TORCH_CHECK(
is_neox == false, " Query/Key with 3D [num_tokens, num_heads, head_size] does not support neox rope yet");
// TODO: add neox style support for rope impl with 3D inputs
rotary_embedding_3D_kernel_impl<scalar_t>(
query_out.data_ptr<scalar_t>(),
key_out.data_ptr<scalar_t>(),
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
num_tokens,
num_heads,
num_kv_heads,
head_size,
rotary_dim,
query_stride_s,
query_out_stride_s,
key_out_stride_s,
key_stride_s,
query_stride_h,
query_out_stride_h);
}
});
return std::make_tuple(q_pe_out, k_pe_out);
return std::make_tuple(query_out, key_out);
}
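A usage sketch for the 2D NeoX path, patterned after the test_origin_rope case added below. It assumes the sgl-kernel CPU extension is built and uses sglang's RotaryEmbedding layer (with the constructor signature used in the test) only to build the cos/sin cache; for 2D inputs the op rotates query/key in place and also returns them:
import torch
import sgl_kernel  # noqa: F401  (registers torch.ops.sgl_kernel.* ops)
from sglang.srt.layers.rotary_embedding import RotaryEmbedding

head_size, rotary_dim = 64, 64
num_heads, num_kv_heads, num_tokens = 8, 2, 16
rope = RotaryEmbedding(head_size, rotary_dim, 2048, 10000, True, torch.bfloat16)
positions = torch.arange(num_tokens)  # int64 positions, one per token
query = torch.randn(num_tokens, num_heads * head_size, dtype=torch.bfloat16)
key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=torch.bfloat16)
# 2D NeoX path: query/key are rotated in place and also returned
q_out, k_out = torch.ops.sgl_kernel.rotary_embedding_cpu(
    positions, query, key, head_size, rope.cos_sin_cache.to(torch.bfloat16), True
)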
......@@ -157,6 +157,101 @@ inline void sigmoid(float* __restrict__ out, const scalar_t* __restrict__ input)
}
}
template <typename scalar_t, int NUM_EXPERTS>
void topk_sigmoid_kernel_impl(
float* __restrict__ topk_weights,
int32_t* __restrict__ topk_ids,
const scalar_t* __restrict__ gating_output,
int64_t num_tokens,
int64_t topk,
bool renormalize) {
using Vec = at::vec::Vectorized<float>;
const int64_t num_experts_per_group = NUM_EXPERTS;
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
alignas(64) float scores[NUM_EXPERTS];
using elem_t = std::pair<float, int32_t>;
std::vector<elem_t> queue(num_experts_per_group);
for (int64_t i = begin; i < end; ++i) {
at::vec::convert<scalar_t, float>(gating_output + i * NUM_EXPERTS, scores, NUM_EXPERTS);
float gmax = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, scores, num_experts_per_group);
// find position of first max,
// note that we may have multiple max values.
int first_max_idx = -1;
for (int64_t e = 0; e < num_experts_per_group; ++e) {
if (scores[e] == gmax) {
first_max_idx = e;
break;
}
}
// scalar sigmoid
topk_weights[i] = 1.0 / (1.0 + exp(0.0 - gmax));
topk_ids[i] = first_max_idx;
if (renormalize) {
float sum = 0.f;
for (int64_t j = 0; j < topk; ++j) {
sum += topk_weights[i * topk + j];
}
float scale = 1.f / sum;
for (int64_t j = 0; j < topk; ++j) {
topk_weights[i * topk + j] *= scale;
}
}
}
});
}
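A minimal eager-mode reference for what the sigmoid kernel above produces per token (the wrapper below restricts topk to 1; with topk == 1 and renormalize the weight collapses to 1.0, matching the kernel):
import torch

def topk_sigmoid_ref(gating_output: torch.Tensor, renormalize: bool = False):
    # top-1 routing: expert id of the (first) max logit, weight = sigmoid(max logit)
    gmax, ids = gating_output.float().max(dim=-1, keepdim=True)
    weights = torch.sigmoid(gmax)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, ids.to(torch.int32)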
template <typename scalar_t, int NUM_EXPERTS>
void topk_softmax_kernel_impl(
float* __restrict__ topk_weights,
int32_t* __restrict__ topk_ids,
const scalar_t* __restrict__ gating_output,
int64_t num_tokens,
int64_t topk,
bool renormalize) {
const int64_t num_experts_per_group = NUM_EXPERTS;
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
alignas(64) float scores[NUM_EXPERTS];
using elem_t = std::pair<float, int32_t>;
std::vector<elem_t> queue(num_experts_per_group);
for (int64_t i = begin; i < end; ++i) {
softmax<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
for (int64_t e = 0; e < num_experts_per_group; ++e) {
queue[e] = {scores[e], e};
}
std::partial_sort(
queue.begin(),
queue.begin() + num_experts_per_group,
queue.end(),
[](const elem_t& x, const elem_t& y) -> bool { return x.first > y.first; });
for (int64_t j = 0; j < topk; ++j) {
topk_weights[i * topk + j] = queue[j].first;
topk_ids[i * topk + j] = queue[j].second;
}
if (renormalize) {
float sum = 0.f;
for (int64_t j = 0; j < topk; ++j) {
sum += topk_weights[i * topk + j];
}
float scale = 1.f / sum;
for (int64_t j = 0; j < topk; ++j) {
topk_weights[i * topk + j] *= scale;
}
}
}
});
}
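Likewise, a reference sketch of the fused softmax top-k routing above (tie-breaking order may differ from the kernel's partial sort):
import torch

def topk_softmax_ref(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # softmax over experts, then keep the top-k weights and expert ids per token
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)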
template <typename scalar_t, int SIZE>
inline void
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
......@@ -293,6 +388,24 @@ void biased_grouped_topk_kernel_impl(
topk_group, \
renormalize);
#define LAUNCH_TOPK_SIGMOID_KERNEL(NE) \
topk_sigmoid_kernel_impl<scalar_t, NE>( \
topk_weights.data_ptr<float>(), \
topk_ids.data_ptr<int32_t>(), \
gating_output.data_ptr<scalar_t>(), \
num_tokens, \
topk, \
renormalize);
#define LAUNCH_TOPK_SOFTMAX_KERNEL(NE) \
topk_softmax_kernel_impl<scalar_t, NE>( \
topk_weights.data_ptr<float>(), \
topk_ids.data_ptr<int32_t>(), \
gating_output.data_ptr<scalar_t>(), \
num_tokens, \
topk, \
renormalize);
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>( \
topk_weights.data_ptr<float>(), \
......@@ -306,6 +419,114 @@ void biased_grouped_topk_kernel_impl(
} // anonymous namespace
std::tuple<at::Tensor, at::Tensor>
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
RECORD_FUNCTION("sgl-kernel::topk_sigmoid_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
CHECK_INPUT(gating_output);
const auto st = hidden_states.scalar_type();
CHECK_EQ(gating_output.scalar_type(), st);
int64_t num_tokens = hidden_states.size(0);
int64_t num_experts = gating_output.size(1);
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
TORCH_CHECK(topk == 1, "topk_sigmoid only supports topk=1 case");
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_sigmoid_kernel", [&] {
switch (num_experts) {
case 1:
LAUNCH_TOPK_SIGMOID_KERNEL(1);
break;
case 2:
LAUNCH_TOPK_SIGMOID_KERNEL(2);
break;
case 4:
LAUNCH_TOPK_SIGMOID_KERNEL(4);
break;
case 8:
LAUNCH_TOPK_SIGMOID_KERNEL(8);
break;
case 16:
LAUNCH_TOPK_SIGMOID_KERNEL(16);
break;
case 32:
LAUNCH_TOPK_SIGMOID_KERNEL(32);
break;
case 64:
LAUNCH_TOPK_SIGMOID_KERNEL(64);
break;
case 128:
LAUNCH_TOPK_SIGMOID_KERNEL(128);
break;
case 160:
LAUNCH_TOPK_SIGMOID_KERNEL(160);
break;
case 256:
LAUNCH_TOPK_SIGMOID_KERNEL(256);
break;
default:
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
}
});
return std::make_tuple(topk_weights, topk_ids);
}
std::tuple<at::Tensor, at::Tensor>
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
RECORD_FUNCTION("sgl-kernel::topk_softmax_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
CHECK_INPUT(gating_output);
const auto st = hidden_states.scalar_type();
CHECK_EQ(gating_output.scalar_type(), st);
int64_t num_tokens = hidden_states.size(0);
int64_t num_experts = gating_output.size(1);
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_softmax_cpu", [&] {
switch (num_experts) {
case 1:
LAUNCH_TOPK_SOFTMAX_KERNEL(1);
break;
case 2:
LAUNCH_TOPK_SOFTMAX_KERNEL(2);
break;
case 4:
LAUNCH_TOPK_SOFTMAX_KERNEL(4);
break;
case 8:
LAUNCH_TOPK_SOFTMAX_KERNEL(8);
break;
case 16:
LAUNCH_TOPK_SOFTMAX_KERNEL(16);
break;
case 32:
LAUNCH_TOPK_SOFTMAX_KERNEL(32);
break;
case 64:
LAUNCH_TOPK_SOFTMAX_KERNEL(64);
break;
case 128:
LAUNCH_TOPK_SOFTMAX_KERNEL(128);
break;
case 160:
LAUNCH_TOPK_SOFTMAX_KERNEL(160);
break;
case 256:
LAUNCH_TOPK_SOFTMAX_KERNEL(256);
break;
default:
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
}
});
return std::make_tuple(topk_weights, topk_ids);
}
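Calling the two new ops from Python, assuming the sgl-kernel CPU extension is built; shapes follow the tests added at the end of this diff:
import torch
import sgl_kernel  # noqa: F401  (registers torch.ops.sgl_kernel.* ops)

hidden_states = torch.randn(8, 100, dtype=torch.bfloat16)
gating_output = torch.randn(8, 16, dtype=torch.bfloat16)
# fused softmax top-k routing
weights, ids = torch.ops.sgl_kernel.topk_softmax_cpu(hidden_states, gating_output, 3, True)
# fused sigmoid routing (the wrapper requires topk == 1)
weights1, ids1 = torch.ops.sgl_kernel.topk_sigmoid_cpu(hidden_states, gating_output, 1, False)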
// grouped topk for DeepSeek V2
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
at::Tensor& hidden_states,
......
......@@ -23,6 +23,9 @@ limitations under the License.
// silu_and_mul
at::Tensor silu_and_mul_cpu(at::Tensor& input);
// l2norm
at::Tensor l2norm_cpu(at::Tensor& input, double eps);
// rmsnorm
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
......@@ -30,6 +33,11 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps);
// topk
std::tuple<at::Tensor, at::Tensor>
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
std::tuple<at::Tensor, at::Tensor>
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
at::Tensor& hidden_states,
at::Tensor& gating_output,
......@@ -185,8 +193,13 @@ void shm_allreduce(
at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
// rope
std::tuple<at::Tensor, at::Tensor>
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos);
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
at::Tensor& positions,
at::Tensor& query,
at::Tensor& key,
int64_t head_size,
at::Tensor& cos_sin_cache,
bool is_neox);
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
// activation
......@@ -196,10 +209,16 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
// norm
m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);
// topk
m.def("topk_sigmoid_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
m.impl("topk_sigmoid_cpu", torch::kCPU, &topk_sigmoid_cpu);
m.def("topk_softmax_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
m.impl("topk_softmax_cpu", torch::kCPU, &topk_softmax_cpu);
m.def(
"grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
"int topk_group) -> (Tensor, Tensor)");
......@@ -294,8 +313,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("shm_allgather", torch::kCPU, &shm_allgather);
// rope
m.def("rotary_position_embedding_cpu(Tensor t_pos, Tensor q_pe, Tensor k_pe, Tensor t_emb_pos) -> (Tensor, Tensor)");
m.impl("rotary_position_embedding_cpu", torch::kCPU, &rotary_position_embedding_cpu);
m.def(
"rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, "
"bool is_neox) -> (Tensor, Tensor)");
m.impl("rotary_embedding_cpu", torch::kCPU, &rotary_embedding_cpu);
}
REGISTER_EXTENSION(common_ops)
......@@ -63,10 +63,24 @@ class TestNorm(CustomTestCase):
self.assertTrue(torch.allclose(x, ref_x, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(residual, ref_residual, atol=atol, rtol=rtol))
def _l2norm_test(self, m, n, dtype):
x = torch.randn([m, n], dtype=dtype)
hidden_size = x.size(-1)
fake_ones_weight = torch.ones(hidden_size, dtype=dtype)
variance_epsilon = 1e-6
out = torch.ops.sgl_kernel.l2norm_cpu(x, variance_epsilon)
ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon)
atol = rtol = precision[ref_out.dtype]
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
def test_norm(self):
for params in itertools.product(self.M, self.N, self.dtype):
with self.subTest(m=params[0], n=params[1], dtype=params[2]):
self._norm_test(*params)
self._l2norm_test(*params)
if __name__ == "__main__":
......
......@@ -4,7 +4,10 @@ import sgl_kernel
import torch
from utils import precision
from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding
from sglang.srt.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding,
RotaryEmbedding,
)
from sglang.test.test_utils import CustomTestCase
......@@ -62,10 +65,13 @@ class TestROPE(CustomTestCase):
)
# fused rope kernel
q_pe_clone, k_pe_clone = (
torch.ops.sgl_kernel.rotary_position_embedding_cpu(
positions, q_pe_clone, k_pe_clone, cos_sin_cache
)
q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_embedding_cpu(
positions,
q_pe_clone,
k_pe_clone,
rope.head_size,
cos_sin_cache,
False,
)
atol = rtol = precision[q_pe.dtype]
......@@ -73,6 +79,98 @@ class TestROPE(CustomTestCase):
self.assertTrue(torch.allclose(k_pe, k_pe_clone, atol=atol, rtol=rtol))
torch.testing.assert_close(k_pe, k_pe_clone)
def test_origin_rope(self):
def single_test(
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
device: str,
batch_size: int,
seq_len: int,
num_q_heads: int,
num_kv_heads: int,
):
torch.manual_seed(100)
rope_ref = RotaryEmbedding(
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
).to(device)
pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
query = torch.randn(
batch_size * seq_len,
num_q_heads * head_size,
dtype=dtype,
device=device,
)
key = torch.randn(
batch_size * seq_len,
num_kv_heads * head_size,
dtype=dtype,
device=device,
)
query_ref, key_ref = query.clone(), key.clone()
query_cpu, key_cpu = query.clone(), key.clone()
query_ref_out, key_ref_out = rope_ref.forward_native(
pos_ids, query_ref, key_ref
)
query_cpu_out, key_cpu_out = torch.ops.sgl_kernel.rotary_embedding_cpu(
pos_ids,
query_cpu,
key_cpu,
rope_ref.head_size,
rope_ref.cos_sin_cache.to(query.dtype),
rope_ref.is_neox_style,
)
torch.testing.assert_close(
query_ref_out, query_cpu_out, atol=1e-2, rtol=1e-2
)
torch.testing.assert_close(key_ref_out, key_cpu_out, atol=1e-2, rtol=1e-2)
test_config = [
(64, 64, 32, 8000, True, torch.bfloat16, "cpu", 32, 32, 1, 1),
(256, 128, 4096, 10000, True, torch.bfloat16, "cpu", 2, 512, 32, 8),
(512, 128, 311, 10000, True, torch.bfloat16, "cpu", 3, 39, 4, 2),
(128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8),
(128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4),
(512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2),
]
for (
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
batch_size,
seq_len,
num_q_heads,
num_kv_heads,
) in test_config:
single_test(
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
batch_size,
seq_len,
num_q_heads,
num_kv_heads,
)
if __name__ == "__main__":
unittest.main()
......@@ -8,7 +8,9 @@ from utils import precision
from sglang.srt.layers.moe.topk import (
biased_grouped_topk_impl as native_biased_grouped_topk,
)
from sglang.srt.layers.moe.topk import fused_topk_native as native_fused_topk
from sglang.srt.layers.moe.topk import grouped_topk as native_grouped_topk
from sglang.srt.models.llama4 import Llama4MoE
from sglang.test.test_utils import CustomTestCase
......@@ -94,5 +96,86 @@ class TestBiasedGroupedTopK(CustomTestCase):
self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
class TestTopK(CustomTestCase):
def _run_single_test(self, M, E, topk, renormalize, dtype):
torch.manual_seed(1998)
# scale gating_output by M, otherwise bfloat16 values collapse to the same value after truncation
hidden_states = torch.randn(M, 100, dtype=dtype)
gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
ref_topk_weights, ref_topk_ids = native_fused_topk(
hidden_states.float(),
gating_output.float(),
topk,
renormalize,
)
# fused version
topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
hidden_states, gating_output, topk, renormalize
)
res = torch.zeros(M, E, dtype=torch.float)
ref = torch.zeros(M, E, dtype=torch.float)
res.scatter_(1, topk_ids.long(), topk_weights)
ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
torch.testing.assert_close(res, ref)
def test_topk(self):
for renormalize in [True, False]:
self._run_single_test(123, 8, 2, renormalize, torch.bfloat16)
self._run_single_test(123, 16, 3, renormalize, torch.bfloat16)
self._run_single_test(123, 32, 3, renormalize, torch.bfloat16)
self._run_single_test(123, 32, 3, renormalize, torch.bfloat16)
self._run_single_test(123, 64, 6, renormalize, torch.bfloat16)
self._run_single_test(123, 256, 4, renormalize, torch.bfloat16)
self._run_single_test(123, 160, 6, renormalize, torch.bfloat16)
class TestCustomTopK(CustomTestCase):
def _run_single_test(
self, M, E, topk, renormalize, dtype, native_custom_f, fused_custom_f
):
torch.manual_seed(16)
# scale gating_output by M, otherwise bfloat16 values collapse to the same value after truncation
hidden_states = torch.randn(M, 100, dtype=dtype)
gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
ref_topk_weights, ref_topk_ids = native_custom_f(
hidden_states.float(),
gating_output.float(),
topk,
renormalize,
)
# fused version
topk_weights, topk_ids = fused_custom_f(
hidden_states, gating_output, topk, renormalize
)
res = torch.zeros(M, E, dtype=torch.float)
ref = torch.zeros(M, E, dtype=torch.float)
res.scatter_(1, topk_ids.long(), topk_weights)
ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
torch.testing.assert_close(res, ref)
def test_custom_topk(self):
test_custom_functions = [
(Llama4MoE.custom_routing_function, torch.ops.sgl_kernel.topk_sigmoid_cpu)
]
for native_custom_f, fused_custom_f in test_custom_functions:
self._run_single_test(
123, 8, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
)
self._run_single_test(
123, 16, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
)
self._run_single_test(
123, 32, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
)
if __name__ == "__main__":
unittest.main()