Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
c7973222
Unverified
Commit
c7973222
authored
Jul 05, 2025
by
Mick
Committed by
GitHub
Jul 04, 2025
Browse files
fix: fix apply_shuffle_mul_sum (#7444)
parent
ef8a29c4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
67 additions
and
43 deletions
+67
-43
sgl-kernel/csrc/moe/prepare_moe_input.cu
sgl-kernel/csrc/moe/prepare_moe_input.cu
+67
-43
No files found.
sgl-kernel/csrc/moe/prepare_moe_input.cu
100755 → 100644
View file @
c7973222
...
@@ -2,9 +2,11 @@
...
@@ -2,9 +2,11 @@
#include <cudaTypedefs.h>
#include <cudaTypedefs.h>
#include <torch/all.h>
#include <torch/all.h>
#include <flashinfer/vec_dtypes.cuh>
#include <iostream>
#include <iostream>
#include "cutlass/array.h"
#include "cutlass/array.h"
#include "utils.h"
// Thread-block size used by the expert-preparation kernels in this file.
// (The scraped diff rendering duplicated this declaration; it exists once.)
constexpr uint64_t THREADS_PER_EXPERT = 512;
...
@@ -255,37 +257,67 @@ void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2sr
...
@@ -255,37 +257,67 @@ void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2sr
// Gather, scale and reduce rows of an expert-major tensor back into
// token-major order:
//
//   output[i, d] = sum_{j < topk} factor(i, j) * input[permutation[i * topk + j], d]
//
// where factor(i, j) = factors[i * topk + j] when `factors` is non-null,
// otherwise 1.0.
//
// Launch contract (per the caller in this file): one thread block per output
// token row (blockIdx.x == i, grid of size m); threads of the block stride
// over the hidden dimension. The main loop moves 16 bytes (vec_size elements)
// per step via flashinfer::vec_t, accumulating in float; a scalar tail loop
// handles the last row_stride % vec_size elements.
//
// NOTE(review): the source span interleaved the pre- and post-fix versions of
// this kernel token-by-token; this is the reconstructed post-fix version.
template <typename scalar_t>
__global__ void apply_shuffle_mul_sum_kernel(
    const scalar_t* __restrict__ input_tensor,   // [m * topk, k] (expert-major layout)
    scalar_t* __restrict__ output_tensor,        // [m, k] (token-major layout)
    const int32_t* __restrict__ permutation,     // [m * topk] (c_map: token-major-idx -> expert-major-idx)
    int m,
    int topk,
    int row_stride,
    const scalar_t* __restrict__ factors) {      // [m * topk] (topk_weights, token-major layout)
  int i = blockIdx.x;  // output token row index in [0, m)
  if (i >= m) {
    return;
  }

  // 16-byte vectorized loads/stores; accumulate in float regardless of
  // scalar_t (half/bf16) to limit rounding error in the topk reduction.
  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  using t = float;
  using vec_t = flashinfer::vec_t<t, vec_size>;

  int thread_idx = threadIdx.x;
  int stride = blockDim.x;

  // Vectorized main part of the row.
  for (int d_vec_idx = thread_idx; d_vec_idx < row_stride / vec_size; d_vec_idx += stride) {
    int d = d_vec_idx * vec_size;
    vec_t sum_vec;
    sum_vec.fill(0.0f);
    for (int j = 0; j < topk; ++j) {
      int token_major_idx = i * topk + j;
      int src_row = permutation[token_major_idx];
      vec_t val_vec;
      val_vec.cast_load(input_tensor + src_row * row_stride + d);
      t factor = 1.0;
      if (factors != nullptr) {
        factor = factors[token_major_idx];
      }
#pragma unroll
      for (int k = 0; k < vec_size; ++k) {
        sum_vec[k] += factor * val_vec[k];
      }
    }
    sum_vec.cast_store(output_tensor + i * row_stride + d);
  }

  // Scalar tail: the last row_stride % vec_size elements of the row.
  int remainder_start = (row_stride / vec_size) * vec_size;
  for (int d = remainder_start + thread_idx; d < row_stride; d += stride) {
    t sum_val = 0.0;
    for (int j = 0; j < topk; ++j) {
      int token_major_idx = i * topk + j;
      int src_row = permutation[token_major_idx];
      t val = input_tensor[src_row * row_stride + d];
      t factor = 1.0;
      if (factors != nullptr) {
        factor = factors[token_major_idx];
      }
      sum_val += factor * val;
    }
    output_tensor[i * row_stride + d] = sum_val;
  }
}
void
get_apply_shuffle_mul_sum_caller
(
void
get_apply_shuffle_mul_sum_caller
(
...
@@ -304,7 +336,11 @@ void get_apply_shuffle_mul_sum_caller(
...
@@ -304,7 +336,11 @@ void get_apply_shuffle_mul_sum_caller(
TORCH_CHECK
(
permutation
.
size
(
0
)
==
m
*
topk
,
"permutation size must match m * topk"
);
TORCH_CHECK
(
permutation
.
size
(
0
)
==
m
*
topk
,
"permutation size must match m * topk"
);
dim3
block
(
std
::
min
(
256
,
row_stride
));
auto
scalar_type
=
output_tensor
.
scalar_type
();
uint32_t
vec_size
=
16
/
sizeof
(
scalar_type
);
auto
blockDim
=
std
::
min
(
row_stride
/
vec_size
,
1024U
);
dim3
block
(
blockDim
);
dim3
grid
(
m
);
// blockIdx.x = j, blockIdx.y = i
dim3
grid
(
m
);
// blockIdx.x = j, blockIdx.y = i
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
input_tensor
.
device
().
index
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
input_tensor
.
device
().
index
());
...
@@ -317,29 +353,17 @@ void get_apply_shuffle_mul_sum_caller(
...
@@ -317,29 +353,17 @@ void get_apply_shuffle_mul_sum_caller(
factors_ptr
=
factors_opt
->
data_ptr
();
factors_ptr
=
factors_opt
->
data_ptr
();
}
}
if
(
output_tensor
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16
(
output_tensor
.
scalar_type
(),
scalar_t
,
[
&
]
{
const
at
::
Half
*
factor_data
=
static_cast
<
const
at
::
Half
*>
(
factors_ptr
);
apply_shuffle_mul_sum_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
apply_shuffle_mul_sum_kernel
<
at
::
Half
><<<
grid
,
block
,
0
,
stream
>>>
(
static_cast
<
const
scalar_t
*>
(
input_tensor
.
data_ptr
()),
input_tensor
.
data_ptr
<
at
::
Half
>
(),
static_cast
<
scalar_t
*>
(
output_tensor
.
data_ptr
()),
output_tensor
.
data_ptr
<
at
::
Half
>
(),
perm_ptr
,
perm_ptr
,
m
,
m
,
topk
,
topk
,
row_stride
,
row_stride
,
static_cast
<
const
at
::
Half
*>
(
factors_ptr
));
static_cast
<
const
scalar_t
*>
(
factors_ptr
));
}
else
if
(
output_tensor
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
return
true
;
const
c10
::
BFloat16
*
factor_data
=
static_cast
<
const
c10
::
BFloat16
*>
(
factors_ptr
);
});
apply_shuffle_mul_sum_kernel
<
c10
::
BFloat16
><<<
grid
,
block
,
0
,
stream
>>>
(
input_tensor
.
data_ptr
<
c10
::
BFloat16
>
(),
output_tensor
.
data_ptr
<
c10
::
BFloat16
>
(),
perm_ptr
,
m
,
topk
,
row_stride
,
static_cast
<
const
c10
::
BFloat16
*>
(
factors_ptr
));
}
else
{
TORCH_CHECK
(
false
,
"Unsupported output dtype for cast+mul kernel: "
,
output_tensor
.
scalar_type
());
}
}
}
/**
/**
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment