Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b2068609
Commit
b2068609
authored
Aug 06, 2024
by
wangmin6
Browse files
fix deepseek_v2 236b fused_moe_kernel expert_ids_ptr value error
parent
9f9f3796
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
50 additions
and
20 deletions
+50
-20
csrc/moe_align_block_size_kernels.cu
csrc/moe_align_block_size_kernels.cu
+50
-20
No files found.
csrc/moe_align_block_size_kernels.cu
View file @
b2068609
...
@@ -9,6 +9,8 @@
...
@@ -9,6 +9,8 @@
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#define MAX_SHARED_MEM_SIZE 64 * 1024
namespace
vllm
{
namespace
vllm
{
namespace
{
namespace
{
...
@@ -19,11 +21,12 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
...
@@ -19,11 +21,12 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
}
}
}
// namespace
}
// namespace
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
bool
experts_num_exceed_limit
>
__global__
void
moe_align_block_size_kernel
(
scalar_t
*
__restrict__
topk_ids
,
__global__
void
moe_align_block_size_kernel
(
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
sorted_token_ids
,
int32_t
*
sorted_token_ids
,
int32_t
*
expert_ids
,
int32_t
*
expert_ids
,
int32_t
*
total_tokens_post_pad
,
int32_t
*
total_tokens_post_pad
,
int32_t
*
global_tokens_cnts_ptr
,
int32_t
num_experts
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
)
{
int32_t
block_size
,
size_t
numel
)
{
const
size_t
tokens_per_thread
=
CEILDIV
(
numel
,
blockDim
.
x
);
const
size_t
tokens_per_thread
=
CEILDIV
(
numel
,
blockDim
.
x
);
...
@@ -31,11 +34,18 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
...
@@ -31,11 +34,18 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
extern
__shared__
int32_t
shared_mem
[];
extern
__shared__
int32_t
shared_mem
[];
int32_t
*
tokens_cnts
=
int32_t
*
tokens_cnts
=
nullptr
;
shared_mem
;
// 2d tensor with shape (num_experts + 1, num_experts)
int32_t
*
cumsum
=
nullptr
;
int32_t
*
cumsum
=
if
(
experts_num_exceed_limit
)
{
shared_mem
+
(
num_experts
+
1
)
*
// 2d tensor with shape (num_experts + 1, num_experts)
num_experts
;
// 1d tensor with shape (num_experts + 1)
tokens_cnts
=
global_tokens_cnts_ptr
;
// 1d tensor with shape (num_experts + 1)
cumsum
=
shared_mem
;
}
else
{
tokens_cnts
=
shared_mem
;
// 2d tensor with shape (num_experts + 1, num_experts)
cumsum
=
shared_mem
+
(
num_experts
+
1
)
*
num_experts
;
// 1d tensor with shape (num_experts + 1)
}
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
i
)]
=
0
;
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
i
)]
=
0
;
...
@@ -115,20 +125,40 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -115,20 +125,40 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
int32_t
shared_mem_normal
=
((
num_experts
+
1
)
*
num_experts
+
(
num_experts
+
1
))
*
// tensors
sizeof
(
int32_t
);
const
int32_t
shared_mem
=
((
num_experts
+
1
)
*
num_experts
+
(
num_experts
+
1
))
*
const
bool
experts_num_exceed_limit
=
shared_mem_normal
>
MAX_SHARED_MEM_SIZE
;
sizeof
(
int32_t
);
// calc needed amount of shared mem for `cumsum`
// set dynamic shared mem
const
int32_t
shared_mem
=
experts_num_exceed_limit
?
(
num_experts
+
1
)
*
sizeof
(
int32_t
)
:
shared_mem_normal
;
auto
kernel
=
vllm
::
moe_align_block_size_kernel
<
scalar_t
>
;
AT_CUDA_CHECK
(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize
(
if
(
experts_num_exceed_limit
)
{
// set dynamic shared mem
auto
kernel
=
vllm
::
moe_align_block_size_kernel
<
scalar_t
,
true
>
;
AT_CUDA_CHECK
(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize
(
(
void
*
)
kernel
,
shared_mem
));
int32_t
tokens_cnts
[(
num_experts
+
1
)
*
num_experts
];
torch
::
Tensor
key_cache_ptrs_tensor
=
torch
::
from_blob
(
tokens_cnts
,
{(
num_experts
+
1
)
*
num_experts
},
torch
::
kInt32
)
.
to
(
topk_ids
.
device
());
kernel
<<<
1
,
num_experts
,
shared_mem
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
key_cache_ptrs_tensor
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
());
}
else
{
// set dynamic shared mem
auto
kernel
=
vllm
::
moe_align_block_size_kernel
<
scalar_t
,
false
>
;
AT_CUDA_CHECK
(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize
(
(
void
*
)
kernel
,
shared_mem
));
(
void
*
)
kernel
,
shared_mem
));
kernel
<<<
1
,
num_experts
,
shared_mem
,
stream
>>>
(
kernel
<<<
1
,
num_experts
,
shared_mem
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
nullptr
,
num_experts
,
block_size
,
topk_ids
.
numel
());
topk_ids
.
numel
());
}
});
});
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment