Unverified commit 08f8f490 authored by Chunyuan WU, committed by GitHub

[CPU][sgl-kernel] biased_grouped_topk: fix correction_bias dtype to float32 (#8212)


Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
parent d4bf5a85
@@ -47,6 +47,45 @@ namespace {
   } \
 }()
 
+// dispatch with mixed dtypes (TYPE1, TYPE2):
+//   TYPE1: the primary dtype (input, output, weight);
+//   TYPE2: the secondary dtype (bias, etc.).
+#define CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(TYPE1, TYPE2, ...) \
+  [&] { \
+    if (TYPE2 == at::kFloat) { \
+      switch (TYPE1) { \
+        case at::ScalarType::BFloat16: { \
+          using scalar_t = at::BFloat16; \
+          using param_t = float; \
+          return __VA_ARGS__(); \
+        } \
+        case at::ScalarType::Half: { \
+          using scalar_t = at::Half; \
+          using param_t = float; \
+          return __VA_ARGS__(); \
+        } \
+        default: \
+          TORCH_CHECK(false, "Unsupported floating data type.\n"); \
+      } \
+    } else { \
+      TORCH_CHECK(TYPE1 == TYPE2); \
+      switch (TYPE1) { \
+        case at::ScalarType::BFloat16: { \
+          using scalar_t = at::BFloat16; \
+          using param_t = at::BFloat16; \
+          return __VA_ARGS__(); \
+        } \
+        case at::ScalarType::Half: { \
+          using scalar_t = at::Half; \
+          using param_t = at::Half; \
+          return __VA_ARGS__(); \
+        } \
+        default: \
+          TORCH_CHECK(false, "Unsupported floating data type.\n"); \
+      } \
+    } \
+  }()
+
 #define UNUSED(x) (void)(x)
 
 #define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
...
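For readers unfamiliar with ATen-style dispatch macros, here is a minimal, ATen-free sketch of the same two-level dispatch the new macro performs: TYPE1 selects the activation dtype `scalar_t`, TYPE2 selects the bias dtype `param_t`, and an fp32 bias is allowed to pair with either reduced dtype. The `ScalarType` enum, the stand-in `BFloat16`/`Half` structs, and the `dispatch_ext`/`PrintTypes` names below are illustrative assumptions, not part of the actual kernel code.

```cpp
#include <cstdint>
#include <cstdio>
#include <stdexcept>

// Stand-ins for the real dtypes; only their sizes matter in this sketch.
enum class ScalarType { BFloat16, Half, Float };
struct BFloat16 { uint16_t bits; };
struct Half     { uint16_t bits; };

// Plays the role of the lambda passed through __VA_ARGS__.
struct PrintTypes {
  template <typename scalar_t, typename param_t>
  void operator()() const {
    std::printf("sizeof(scalar_t)=%zu  sizeof(param_t)=%zu\n",
                sizeof(scalar_t), sizeof(param_t));
  }
};

// Mirrors CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT: an fp32 bias pairs with
// either reduced activation dtype; otherwise the two dtypes must match.
template <typename Fn>
void dispatch_ext(ScalarType type1, ScalarType type2, Fn fn) {
  if (type2 == ScalarType::Float) {
    switch (type1) {
      case ScalarType::BFloat16: return fn.template operator()<BFloat16, float>();
      case ScalarType::Half:     return fn.template operator()<Half, float>();
      default: throw std::runtime_error("Unsupported floating data type.");
    }
  }
  if (type1 != type2) throw std::runtime_error("activation/bias dtype mismatch");
  switch (type1) {
    case ScalarType::BFloat16: return fn.template operator()<BFloat16, BFloat16>();
    case ScalarType::Half:     return fn.template operator()<Half, Half>();
    default: throw std::runtime_error("Unsupported floating data type.");
  }
}

int main() {
  dispatch_ext(ScalarType::BFloat16, ScalarType::Float, PrintTypes{});  // bf16 activations, fp32 bias
  dispatch_ext(ScalarType::Half, ScalarType::Half, PrintTypes{});       // f16 activations, f16 bias
}
```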
@@ -252,29 +252,33 @@ void topk_softmax_kernel_impl(
   });
 }
 
-template <typename scalar_t, int SIZE>
+template <typename scalar_t, typename param_t, int SIZE>
 inline void
-apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
-  using bVec = at::vec::Vectorized<scalar_t>;
+apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const param_t* __restrict__ bias) {
   using fVec = at::vec::Vectorized<float>;
-  for (int d = 0; d < SIZE; d += bVec::size()) {
-    bVec bias_vec = bVec::loadu(bias + d);
-    fVec bias0, bias1;
-    std::tie(bias0, bias1) = at::vec::convert_to_float(bias_vec);
-
-    fVec x0 = fVec::loadu(scores + d) + bias0;
-    fVec x1 = fVec::loadu(scores + d + fVec::size()) + bias1;
+  using bVec = at::vec::Vectorized<scalar_t>;
+  auto vec_size = bVec::size();
+  int d = 0;
+  for (; d <= SIZE - vec_size; d += vec_size) {
+    fVec bias0, bias1, x0, x1;
+    std::tie(bias0, bias1) = load_float_vec2(bias + d);
+    std::tie(x0, x1) = load_float_vec2(scores + d);
+    x0 = x0 + bias0;
+    x1 = x1 + bias1;
     x0.store(scores2 + d);
     x1.store(scores2 + d + fVec::size());
   }
+  for (; d < SIZE; d++) {
+    scores2[d] = scores[d] + (float)bias[d];
+  }
 }
 
-template <typename scalar_t, int NUM_EXPERTS, int TOPK>
+template <typename scalar_t, typename param_t, int NUM_EXPERTS, int TOPK>
 void biased_grouped_topk_kernel_impl(
     float* __restrict__ topk_weights,
     int32_t* __restrict__ topk_ids,
     const scalar_t* __restrict__ gating_output,
-    const scalar_t* __restrict__ bias,
+    const param_t* __restrict__ bias,
     int64_t num_tokens,
     int64_t num_groups,
     int64_t topk_group,
@@ -295,7 +299,8 @@ void biased_grouped_topk_kernel_impl(
   for (int64_t i = begin; i < end; ++i) {
     // do sigmoid to get scores
     sigmoid<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
-    apply_bias<scalar_t, NUM_EXPERTS>(scores2, scores, bias);
+
+    apply_bias<scalar_t, param_t, NUM_EXPERTS>(scores2, scores, bias);
 
     for (int64_t g = 0; g < num_groups; ++g) {
       // find the max
@@ -407,11 +412,11 @@ void biased_grouped_topk_kernel_impl(
       renormalize);
 
 #define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
-  biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>( \
+  biased_grouped_topk_kernel_impl<scalar_t, param_t, NE, NTOPK>( \
       topk_weights.data_ptr<float>(), \
       topk_ids.data_ptr<int32_t>(), \
       gating_output.data_ptr<scalar_t>(), \
-      correction_bias.data_ptr<scalar_t>(), \
+      correction_bias.data_ptr<param_t>(), \
       num_tokens, \
       num_expert_group, \
       topk_group, \
@@ -635,7 +640,6 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
 
   const auto st = hidden_states.scalar_type();
   CHECK_EQ(gating_output.scalar_type(), st);
-  CHECK_EQ(correction_bias.scalar_type(), st);
 
   int64_t num_tokens = hidden_states.size(0);
   int64_t num_experts = gating_output.size(1);
@@ -644,8 +648,7 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
   at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
   at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
 
-  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "biased_grouped_topk_kernel", [&] {
-    // NOW only support DSv3 configs
+  CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(st, correction_bias.scalar_type(), "biased_grouped_topk_kernel", [&] {
     TORCH_CHECK(topk == 8, "Unexpected topk: ", topk);
     switch (num_experts) {
       case 256:
...
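As a sanity reference (not part of the change itself), here is a plain scalar sketch of what the reworked `apply_bias` computes: a vec_size-wide main loop plus the new scalar tail, so every element gets `scores2[d] = scores[d] + float(bias[d])` even when `SIZE` is not a multiple of the vector width, and the bias type `param_t` may be fp32 or the reduced activation dtype. The `apply_bias_ref` name and the fixed `vec_size` are assumptions for illustration only.

```cpp
#include <cstdio>

template <typename param_t, int SIZE>
void apply_bias_ref(float* scores2, const float* scores, const param_t* bias) {
  constexpr int vec_size = 32;  // stand-in for bVec::size()
  int d = 0;
  // main loop: full vec_size-wide chunks, like the Vectorized<> path
  for (; d <= SIZE - vec_size; d += vec_size) {
    for (int j = 0; j < vec_size; ++j) {
      scores2[d + j] = scores[d + j] + static_cast<float>(bias[d + j]);
    }
  }
  // tail: whatever is left when SIZE is not a multiple of vec_size
  for (; d < SIZE; ++d) {
    scores2[d] = scores[d] + static_cast<float>(bias[d]);
  }
}

int main() {
  constexpr int SIZE = 100;  // deliberately not a multiple of vec_size
  float scores[SIZE], scores2[SIZE], bias[SIZE];
  for (int i = 0; i < SIZE; ++i) { scores[i] = 0.5f * i; bias[i] = -0.25f * i; }
  apply_bias_ref<float, SIZE>(scores2, scores, bias);
  std::printf("scores2[99] = %f\n", scores2[99]);  // 0.25 * 99 = 24.75
}
```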
@@ -16,6 +16,25 @@ inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, c
   return at::vec::convert_from_float<scalar_t>(a, b);
 }
 
+// allow f16, bf16
+template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 1>
+inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const scalar_t* __restrict__ data) {
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+  bVec x_vec = bVec::loadu(data);
+  fVec x0, x1;
+  std::tie(x0, x1) = at::vec::convert_to_float(x_vec);
+  return std::make_tuple(x0, x1);
+}
+
+// allow f32
+inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const float* __restrict__ data) {
+  using fVec = at::vec::Vectorized<float>;
+  fVec x0 = fVec::loadu(data);
+  fVec x1 = fVec::loadu(data + fVec::size());
+  return std::make_tuple(x0, x1);
+}
+
 #if defined(CPU_CAPABILITY_AVX512)
 // `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics
...
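The two `load_float_vec2` overloads rely on ordinary C++ overload resolution: the SFINAE-constrained template catches the reduced dtypes (one wide load, then a widening convert to two float vectors), while the plain `float` overload just issues two contiguous loads. Below is a compacted, ATen-free sketch of that pattern; the `FakeBF16` and `FVec` stand-ins and the 4-lane width are made up for illustration.

```cpp
#include <array>
#include <cstdio>
#include <tuple>
#include <type_traits>

struct FakeBF16 { float value; };  // stand-in for at::BFloat16

template <typename T>
inline constexpr bool is_reduced_v = std::is_same_v<T, FakeBF16>;

using FVec = std::array<float, 4>;  // stand-in for Vectorized<float>

// reduced dtype: one wide "load" of 2*FVec narrow elements, widened to float
template <typename scalar_t, std::enable_if_t<is_reduced_v<scalar_t>, int> = 1>
std::tuple<FVec, FVec> load_float_vec2(const scalar_t* data) {
  FVec x0, x1;
  for (int i = 0; i < 4; ++i) { x0[i] = data[i].value; x1[i] = data[i + 4].value; }
  return {x0, x1};
}

// float: already fp32, just two plain "vector" loads
std::tuple<FVec, FVec> load_float_vec2(const float* data) {
  FVec x0, x1;
  for (int i = 0; i < 4; ++i) { x0[i] = data[i]; x1[i] = data[i + 4]; }
  return {x0, x1};
}

int main() {
  float as_f32[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  FakeBF16 as_bf16[8];
  for (int i = 0; i < 8; ++i) as_bf16[i] = {float(i)};
  auto [a0, a1] = load_float_vec2(as_f32);   // picks the float overload
  auto [b0, b1] = load_float_vec2(as_bf16);  // picks the reduced-dtype overload
  std::printf("%.1f %.1f %.1f %.1f\n", a0[0], a1[3], b0[0], b1[3]);  // 0.0 7.0 0.0 7.0
}
```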
@@ -66,13 +66,15 @@ class TestGroupedTopK(CustomTestCase):
 
 # DeepSeek V2/V3/R1 uses biased_grouped_topk
 class TestBiasedGroupedTopK(CustomTestCase):
-    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
+    def _run_single_test(
+        self, M, E, G, topk, topk_group, renormalize, dtype, bias_dtype
+    ):
         torch.manual_seed(1234)
 
         # expand gating_output by M, otherwise bfloat16 values collapse to the same value after truncating
         hidden_states = torch.randn(M, 100, dtype=dtype)
         gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
-        correction_bias = torch.randn(E, dtype=dtype)
+        correction_bias = torch.randn(E, dtype=bias_dtype)
 
         ref_topk_weights, ref_topk_ids = native_biased_grouped_topk(
             hidden_states.float(),
@@ -106,7 +108,10 @@ class TestBiasedGroupedTopK(CustomTestCase):
 
     def test_biased_grouped_topk(self):
         for renormalize in [True, False]:
-            self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
+            for bias_dtype in [torch.float32, torch.bfloat16]:
+                self._run_single_test(
+                    122, 256, 8, 8, 2, renormalize, torch.bfloat16, bias_dtype
+                )
 
 
 class TestTopK(CustomTestCase):
...