Unverified commit 08f8f490, authored by Chunyuan WU, committed by GitHub

[CPU][sgl-kernel] biased_grouped_topk: fix correction_bias dtype to float32 (#8212)


Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
parent d4bf5a85
@@ -47,6 +47,45 @@ namespace {
   } \
 }()
 
+// dispatch with mixed dtypes (TYPE1, TYPE2):
+//   TYPE1: the primary dtype (input, output, weight);
+//   TYPE2: the secondary dtype (bias, etc.).
+#define CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(TYPE1, TYPE2, ...)  \
+  [&] {                                                             \
+    if (TYPE2 == at::kFloat) {                                      \
+      switch (TYPE1) {                                              \
+        case at::ScalarType::BFloat16: {                            \
+          using scalar_t = at::BFloat16;                            \
+          using param_t = float;                                    \
+          return __VA_ARGS__();                                     \
+        }                                                           \
+        case at::ScalarType::Half: {                                \
+          using scalar_t = at::Half;                                \
+          using param_t = float;                                    \
+          return __VA_ARGS__();                                     \
+        }                                                           \
+        default:                                                    \
+          TORCH_CHECK(false, "Unsupported floating data type.\n");  \
+      }                                                             \
+    } else {                                                        \
+      TORCH_CHECK(TYPE1 == TYPE2);                                  \
+      switch (TYPE1) {                                              \
+        case at::ScalarType::BFloat16: {                            \
+          using scalar_t = at::BFloat16;                            \
+          using param_t = at::BFloat16;                             \
+          return __VA_ARGS__();                                     \
+        }                                                           \
+        case at::ScalarType::Half: {                                \
+          using scalar_t = at::Half;                                \
+          using param_t = at::Half;                                 \
+          return __VA_ARGS__();                                     \
+        }                                                           \
+        default:                                                    \
+          TORCH_CHECK(false, "Unsupported floating data type.\n");  \
+      }                                                             \
+    }                                                               \
+  }()
+
 #define UNUSED(x) (void)(x)
 #define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
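
For readers less familiar with the `AT_DISPATCH`-style pattern: the macro switches on the two runtime dtypes and binds the type aliases `scalar_t` and `param_t` for the lambda body, so one call site covers both the matched-dtype and the fp32-bias cases. A minimal usage sketch, with a hypothetical `my_add_bias` wrapper that is not part of this PR:

```cpp
#include <ATen/ATen.h>

// Hypothetical caller (illustration only). With input=bf16 and bias=fp32,
// the lambda is instantiated with scalar_t = at::BFloat16, param_t = float;
// with bias=bf16 it gets scalar_t = param_t = at::BFloat16.
void my_add_bias(const at::Tensor& input, const at::Tensor& bias) {
  CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(
      input.scalar_type(), bias.scalar_type(), "my_add_bias", [&] {
        const scalar_t* x = input.data_ptr<scalar_t>();
        const param_t* b = bias.data_ptr<param_t>();
        UNUSED(x);  // a real kernel body would consume x and b here
        UNUSED(b);
      });
}
```

Note that the name string travels in `__VA_ARGS__` and is swallowed by the comma operator inside `return __VA_ARGS__();`, so the call-site shape matches the stock `AT_DISPATCH` macros.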
@@ -252,29 +252,33 @@ void topk_softmax_kernel_impl(
   });
 }
 
-template <typename scalar_t, int SIZE>
+template <typename scalar_t, typename param_t, int SIZE>
 inline void
-apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
-  using bVec = at::vec::Vectorized<scalar_t>;
+apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const param_t* __restrict__ bias) {
   using fVec = at::vec::Vectorized<float>;
-  for (int d = 0; d < SIZE; d += bVec::size()) {
-    bVec bias_vec = bVec::loadu(bias + d);
-    fVec bias0, bias1;
-    std::tie(bias0, bias1) = at::vec::convert_to_float(bias_vec);
-    fVec x0 = fVec::loadu(scores + d) + bias0;
-    fVec x1 = fVec::loadu(scores + d + fVec::size()) + bias1;
+  using bVec = at::vec::Vectorized<scalar_t>;
+  auto vec_size = bVec::size();
+  int d = 0;
+  for (; d <= SIZE - vec_size; d += vec_size) {
+    fVec bias0, bias1, x0, x1;
+    std::tie(bias0, bias1) = load_float_vec2(bias + d);
+    std::tie(x0, x1) = load_float_vec2(scores + d);
+    x0 = x0 + bias0;
+    x1 = x1 + bias1;
     x0.store(scores2 + d);
     x1.store(scores2 + d + fVec::size());
   }
+  for (; d < SIZE; d++) {
+    scores2[d] = scores[d] + (float)bias[d];
+  }
 }
 
-template <typename scalar_t, int NUM_EXPERTS, int TOPK>
+template <typename scalar_t, typename param_t, int NUM_EXPERTS, int TOPK>
 void biased_grouped_topk_kernel_impl(
     float* __restrict__ topk_weights,
     int32_t* __restrict__ topk_ids,
     const scalar_t* __restrict__ gating_output,
-    const scalar_t* __restrict__ bias,
+    const param_t* __restrict__ bias,
     int64_t num_tokens,
     int64_t num_groups,
     int64_t topk_group,
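
Numerically, the rewritten `apply_bias` above reduces to the following scalar sketch (illustration only, not code from the PR): the bias is widened to fp32 before the add, whatever its storage dtype, and the new scalar tail loop covers `SIZE` values that are not a multiple of the vector width:

```cpp
// Scalar reference for the new apply_bias. `scores` is already fp32
// (post-sigmoid); `bias` may be bf16, fp16, or fp32.
template <typename param_t, int SIZE>
void apply_bias_ref(float* scores2, const float* scores, const param_t* bias) {
  for (int d = 0; d < SIZE; ++d) {
    scores2[d] = scores[d] + static_cast<float>(bias[d]);
  }
}
```

The behavioral fix is the `param_t = float` case: the fp32 `correction_bias` no longer has to be truncated to the activation dtype before it is added.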
@@ -295,7 +299,8 @@ void biased_grouped_topk_kernel_impl(
   for (int64_t i = begin; i < end; ++i) {
     // do sigmoid to get scores
     sigmoid<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
-    apply_bias<scalar_t, NUM_EXPERTS>(scores2, scores, bias);
+    apply_bias<scalar_t, param_t, NUM_EXPERTS>(scores2, scores, bias);
+
     for (int64_t g = 0; g < num_groups; ++g) {
       // find the max
@@ -406,15 +411,15 @@ void biased_grouped_topk_kernel_impl(
       topk,                                       \
       renormalize);
 
-#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK)    \
-  biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>( \
-      topk_weights.data_ptr<float>(),                   \
-      topk_ids.data_ptr<int32_t>(),                     \
-      gating_output.data_ptr<scalar_t>(),               \
-      correction_bias.data_ptr<scalar_t>(),             \
-      num_tokens,                                       \
-      num_expert_group,                                 \
-      topk_group,                                       \
+#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK)             \
+  biased_grouped_topk_kernel_impl<scalar_t, param_t, NE, NTOPK>( \
+      topk_weights.data_ptr<float>(),                            \
+      topk_ids.data_ptr<int32_t>(),                              \
+      gating_output.data_ptr<scalar_t>(),                        \
+      correction_bias.data_ptr<param_t>(),                       \
+      num_tokens,                                                \
+      num_expert_group,                                          \
+      topk_group,                                                \
       renormalize);
 
 } // anonymous namespace
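
Spelled out, the new launch macro reaches these instantiations, where the fp32-bias rows are the paths this PR adds (a sketch of the dispatch outcomes, not code from the PR):

```cpp
// scalar_t / param_t pairs reached via LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK):
//   gating_output bf16, correction_bias fp32 -> impl<at::BFloat16, float,        NE, NTOPK>
//   gating_output fp16, correction_bias fp32 -> impl<at::Half,     float,        NE, NTOPK>
//   gating_output bf16, correction_bias bf16 -> impl<at::BFloat16, at::BFloat16, NE, NTOPK>
//   gating_output fp16, correction_bias fp16 -> impl<at::Half,     at::Half,     NE, NTOPK>
// Mismatched non-fp32 combinations fail the TORCH_CHECK(TYPE1 == TYPE2) branch.
```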
@@ -635,7 +640,6 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
   const auto st = hidden_states.scalar_type();
   CHECK_EQ(gating_output.scalar_type(), st);
-  CHECK_EQ(correction_bias.scalar_type(), st);
 
   int64_t num_tokens = hidden_states.size(0);
   int64_t num_experts = gating_output.size(1);
@@ -644,8 +648,7 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
   at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
   at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
 
-  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "biased_grouped_topk_kernel", [&] {
-    // NOW only support DSv3 configs
+  CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(st, correction_bias.scalar_type(), "biased_grouped_topk_kernel", [&] {
     TORCH_CHECK(topk == 8, "Unexpected topk: ", topk);
     switch (num_experts) {
       case 256:
@@ -16,6 +16,25 @@ inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, c
   return at::vec::convert_from_float<scalar_t>(a, b);
 }
 
+// allow f16, bf16
+template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 1>
+inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const scalar_t* __restrict__ data) {
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+  bVec x_vec = bVec::loadu(data);
+  fVec x0, x1;
+  std::tie(x0, x1) = at::vec::convert_to_float(x_vec);
+  return std::make_tuple(x0, x1);
+}
+
+// allow f32
+inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const float* __restrict__ data) {
+  using fVec = at::vec::Vectorized<float>;
+  fVec x0 = fVec::loadu(data);
+  fVec x1 = fVec::loadu(data + fVec::size());
+  return std::make_tuple(x0, x1);
+}
+
 #if defined(CPU_CAPABILITY_AVX512)
 // `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics
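
The point of the two overloads is a uniform return shape: both yield a pair of fp32 `Vectorized<float>`, so a generic caller such as `apply_bias` compiles unchanged for bf16, fp16, and fp32 pointers. A minimal sketch with a hypothetical `widen_to_fp32` helper (assumes this header is included, and vector widths where one reduced-float vector converts to two fp32 vectors):

```cpp
#include <tuple>

// Hypothetical helper, not in the PR. Loads 2 * fVec::size() elements
// starting at `in` and stores them widened to fp32: the reduced-float
// overload converts bf16/fp16 lanes, the float overload just loads twice.
template <typename T>
void widen_to_fp32(float* out, const T* in) {
  using fVec = at::vec::Vectorized<float>;
  fVec x0, x1;
  std::tie(x0, x1) = load_float_vec2(in);
  x0.store(out);
  x1.store(out + fVec::size());
}
```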
@@ -66,13 +66,15 @@ class TestGroupedTopK(CustomTestCase):
 
 # DeepSeek V2/V3/R1 uses biased_grouped_topk
 class TestBiasedGroupedTopK(CustomTestCase):
-    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
+    def _run_single_test(
+        self, M, E, G, topk, topk_group, renormalize, dtype, bias_dtype
+    ):
         torch.manual_seed(1234)
 
         # scale gating_output by M, otherwise bfloat16 values collapse to the same value after truncation
         hidden_states = torch.randn(M, 100, dtype=dtype)
         gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
-        correction_bias = torch.randn(E, dtype=dtype)
+        correction_bias = torch.randn(E, dtype=bias_dtype)
 
         ref_topk_weights, ref_topk_ids = native_biased_grouped_topk(
             hidden_states.float(),
@@ -106,7 +108,10 @@ class TestBiasedGroupedTopK(CustomTestCase):
     def test_biased_grouped_topk(self):
         for renormalize in [True, False]:
-            self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
+            for bias_dtype in [torch.float32, torch.bfloat16]:
+                self._run_single_test(
+                    122, 256, 8, 8, 2, renormalize, torch.bfloat16, bias_dtype
+                )
 
 
 class TestTopK(CustomTestCase):