OpenDAS / ColossalAI · Commit 5bbefeb0
Authored May 13, 2022 by XYE · Committed May 17, 2022 by binmakeswell
[NFC] polish moe_cuda_kernel.cu code style (#940)
Co-authored-by: Xiao Ye <xiaoye2@illinois.edu>
Parent: 7aa35eae
Showing 1 changed file with 25 additions and 52 deletions.

colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu (+25 -52)
#include "block_reduce.h"
#include <cub/cub.cuh>
#include <cuda.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <torch/extension.h>
#include <cub/cub.cuh>
#include "block_reduce.h"
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;
  ...
@@ -28,7 +29,6 @@ __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
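Note: the assert above pins down the invariant these kernels rely on: every row length is a multiple of pack_size, so data moves in fixed-size packs, bpack_size = block_size * pack_size elements per block pass. A minimal sketch of that access pattern (hypothetical helper, assuming blockDim.x == block_size; not the file's actual body):

template <typename T, int block_size, int pack_size>
__device__ void copy_row_sketch(T *src_row, T *dst_row, const int cols) {
  // each thread owns every block_size-th pack of the row
  const int bpack_size = block_size * pack_size;
  for (int base = threadIdx.x * pack_size; base < cols; base += bpack_size) {
#pragma unroll
    for (int i = 0; i < pack_size; ++i)
      dst_row[base + i] = src_row[base + i];
  }
}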
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;
  ...
@@ -51,7 +51,6 @@ __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
                                 const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;
  ...
@@ -75,7 +74,6 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
                                 const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;
  ...
@@ -105,7 +103,6 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
                               const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;
  ...
@@ -134,7 +131,6 @@ __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
                               T *weight_grad, const T weight, const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;
  ...
@@ -164,15 +160,13 @@ __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
  blockReduce<ReduceType::kSum, 1>(&thread_sum);
  if (threadIdx.x == 0)
    *weight_grad = static_cast<T>(thread_sum);
}
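This fragment is the weight-gradient epilogue: each thread holds a partial sum, blockReduce (declared in block_reduce.h) folds the partials across the block, and thread 0 writes the scalar. A hedged reconstruction of the whole pattern, with the per-thread accumulation assumed to be a strided dot product and the variable names assumed:

template <typename T>
__device__ void weight_grad_sketch(T *grad_row, T *tks_row, T *weight_grad,
                                   const int cols) {
  float thread_sum = 0.0f;  // accumulate in float even when T is half
  for (int i = threadIdx.x; i < cols; i += blockDim.x)
    thread_sum +=
        static_cast<float>(grad_row[i]) * static_cast<float>(tks_row[i]);
  blockReduce<ReduceType::kSum, 1>(&thread_sum);
  if (threadIdx.x == 0) *weight_grad = static_cast<T>(thread_sum);
}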
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row,
                               const T weight1, const T weight2,
                               const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;
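From moe_cb_two_fwd's signature, the two-expert combine is a weighted elementwise sum of two expert rows. A scalar sketch of that computation (hypothetical helper; the real body additionally uses pack_size-wide accesses):

template <typename T>
__device__ void combine_two_sketch(T *src_row1, T *src_row2, T *dst_row,
                                   const T weight1, const T weight2,
                                   const int cols) {
  for (int i = threadIdx.x; i < cols; i += blockDim.x)
    dst_row[i] = weight1 * src_row1[i] + weight2 * src_row2[i];
}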
  ...
@@ -204,7 +198,6 @@ __device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
                               T *tks_row1, T *tks_row2, T *weight_grad1,
                               T *weight_grad2, const T weight1,
                               const T weight2, const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;
  ...
@@ -251,7 +244,6 @@ template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2,
                                      const int cols, const int indicator1,
                                      const int indicator2) {
  if (indicator1 != 0 && indicator2 != 0)
    moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
                                               cols);
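Only the both-experts branch of the selector survives in this hunk. A hedged sketch of the full ladder, assuming the remaining branches dispatch to the single-destination kernel (the else cases are inferred, not shown in the diff):

  if (indicator1 != 0 && indicator2 != 0)
    moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
                                               cols);
  else if (indicator1 != 0)
    moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row1, cols);
  else if (indicator2 != 0)
    moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row2, cols);
  // else: the token was routed to neither expert and nothing is copied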
  ...
@@ -267,7 +259,6 @@ template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2,
                                      const int cols, const int indicator1,
                                      const int indicator2) {
  if (indicator1 != 0 && indicator2 != 0)
    moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
                                               cols);
  ...
@@ -283,7 +274,6 @@ template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input,
                                    int *mask1, int *mask2, int *dest1,
                                    int *dest2, const int h) {
  int row = blockIdx.x;
  int indicator2 = mask2 == nullptr ? 0 : mask2[row];
  moe_dpch_fwd_selector<T, block_size, pack_size>(
  ...
@@ -295,7 +285,6 @@ template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1,
                                    int *mask2, int *dest1, int *dest2,
                                    const int h) {
  int row = blockIdx.x;
  int indicator2 = mask2 == nullptr ? 0 : mask2[row];
  moe_dpch_bwd_selector<T, block_size, pack_size>(
  ...
@@ -310,7 +299,6 @@ __device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
                                    const int cols, const T weight1,
                                    const T weight2, const int indicator1,
                                    const int indicator2) {
  if (indicator1 != 0 && indicator2 != 0)
    moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
                                             weight1, weight2, cols);
  ...
@@ -328,7 +316,6 @@ __device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
                                    T *wt_grad1, T *wt_grad2, const T weight1,
                                    const T weight2, const int indicator1,
                                    const int indicator2) {
  if (indicator1 != 0 && indicator2 != 0)
    moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
                                             tks_row1, tks_row2, wt_grad1,
  ...
@@ -348,7 +335,6 @@ __global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
                                  T *logits, int *mask1, int *mask2, int *dest1,
                                  int *dest2, const int e, const int c,
                                  const int h) {
  int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
  int indicator2 = mask2 == nullptr ? 0 : mask2[row];
  T *row_log = logits + (row * e);
  ...
@@ -363,7 +349,6 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
                                  T *logits, T *logits_grad, int *mask1,
                                  int *mask2, int *dest1, int *dest2,
                                  const int e, const int c, const int h) {
  int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
  int indicator2 = mask2 == nullptr ? 0 : mask2[row];
  T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
  ...
@@ -379,7 +364,6 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
template <int block_size, int pack_size>
__global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
                              const int e) {
  assert(s % pack_size == 0);
  constexpr int bpack_size = block_size * pack_size;
  int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1;
  ...
@@ -426,8 +410,7 @@ __global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
  }
  __syncthreads();
  if (tid == 0)
    temp[0] = temp[block_size];
  __syncthreads();
  if (idx + tps < s) {
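The visible lines are the tile-carry step of a block-wide prefix sum: after one tile of the column is scanned in shared memory, thread 0 moves the tile total from temp[block_size] into temp[0] so the next tile continues from the running sum. A self-contained sketch of the same idea (simplified to pack_size == 1, Hillis-Steele scan, names assumed; the -1 mirrors last_sum = -1 above and cumsum_sub_one_in_dim0 below):

template <int block_size>
__global__ void cumsum_sketch(int *inputs, int *outputs, const int s,
                              const int e) {
  // one block per expert column of an (s, e) mask, as in <<<e, block_size>>>
  __shared__ int temp[block_size];
  int tid = threadIdx.x, bid = blockIdx.x, carry = 0;
  for (int base = 0; base < s; base += block_size) {
    int idx = base + tid;
    temp[tid] = (idx < s) ? inputs[idx * e + bid] : 0;
    __syncthreads();
    for (int off = 1; off < block_size; off <<= 1) {  // inclusive tile scan
      int add = (tid >= off) ? temp[tid - off] : 0;
      __syncthreads();
      temp[tid] += add;
      __syncthreads();
    }
    if (idx < s) outputs[idx * e + bid] = temp[tid] + carry - 1;
    carry += temp[block_size - 1];  // carry the tile total forward
    __syncthreads();
  }
}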
  ...
@@ -453,7 +436,6 @@ template <typename T>
void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
                         int *mask2, int *dest1, int *dest2, const int s,
                         const int h) {
  if (h < 256)
    moe_dpch_fwd_kernel<T, 32, 4>
        <<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
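Only the h < 256 rung of the launch ladder is visible here; the moe_cb_bwd_launch hunk below shows the next rung (<T, 64, 4> on 64 threads). A hedged sketch of the ladder's shape, with the thresholds beyond 512 extrapolated rather than taken from the diff:

  if (h < 256)
    moe_dpch_fwd_kernel<T, 32, 4>
        <<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
  else if (h < 512)
    moe_dpch_fwd_kernel<T, 64, 4>
        <<<s, 64>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
  else if (h < 1024)
    moe_dpch_fwd_kernel<T, 128, 4>
        <<<s, 128>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);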
  ...
@@ -474,7 +456,6 @@ void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
template <typename T>
void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2,
                         int *dest1, int *dest2, const int s, const int h) {
  if (h < 256)
    moe_dpch_bwd_kernel<T, 32, 4>
        <<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
  ...
@@ -496,7 +477,6 @@ template <typename T>
void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
                       int *mask1, int *mask2, int *dest1, int *dest2,
                       const int s, const int e, const int c, const int h) {
  if (h < 256)
    moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens,
                                           logits, mask1, mask2, dest1, dest2,
  ...
@@ -524,12 +504,11 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
                       T *logits_grad, int *mask1, int *mask2, int *dest1,
                       int *dest2, const int s, const int e, const int c,
                       const int h) {
  if (h < 256)
    moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks,
                                           logits, logits_grad, mask1, mask2,
                                           dest1, dest2, e, c, h);
  else  // if (h < 512)
    moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>(tokens_grad, expert_grad, tks,
                                           logits, logits_grad, mask1, mask2,
                                           dest1, dest2, e, c, h);
  ...
@@ -544,7 +523,6 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
}

void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
  if (s <= 256)
    cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
  else if (s <= 512)
  ...
@@ -559,27 +537,26 @@ void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {

// API FUNCTIONS --------------------------------

#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...)                        \
  switch (TYPE) {                                                       \
    case at::ScalarType::Float: {                                       \
      using scalar_t = float;                                           \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    case at::ScalarType::Half: {                                        \
      using scalar_t = at::Half;                                        \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    default:                                                            \
      AT_ERROR(#NAME, " not implemented yet for specific data type.");  \
  }
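A hedged usage sketch of the macro: TYPE picks the case, NAME feeds only the error message, and the variadic body runs with scalar_t bound to float or at::Half. The argument wiring below is illustrative, not the file's real call:

  DISPATCH_FLOAT_AND_HALF(
      batch_tokens.scalar_type(), moe_dispatch_cuda_forward,
      moe_dpch_fwd_launch<scalar_t>(
          batch_tokens.data_ptr<scalar_t>(), res.data_ptr<scalar_t>(),
          mask.data_ptr<int>(), nullptr, dest_idx.data_ptr<int>(), nullptr,
          s, h));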
torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
                                        torch::Tensor batch_tokens,
                                        torch::Tensor mask,
                                        torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  auto res = torch::zeros(
      {ec, h},
  ...
@@ -601,7 +578,6 @@ torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
                                         torch::Tensor expert_grad,
                                         torch::Tensor mask,
                                         torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  auto res = torch::zeros(
      {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
  ...
@@ -622,7 +598,6 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
                                       torch::Tensor expert_tokens,
                                       torch::Tensor logits, torch::Tensor mask,
                                       torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  assert(expert_tokens.dtype() == logits.dtype());
  ...
@@ -643,11 +618,10 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
  return res;
}

std::vector<torch::Tensor>
moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
                          torch::Tensor expert_tokens, torch::Tensor logits,
                          torch::Tensor mask, torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  assert(tokens_grad.dtype() == expert_tokens.dtype());
  assert(expert_tokens.dtype() == logits.dtype());
  ...
@@ -673,7 +647,6 @@ moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
}

torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
  assert(mask.dim() == 2);
  assert(mask.dtype() == torch::kInt32);
  ...
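These API functions reach Python through a torch extension; the binding file is outside this diff, so the following is a hypothetical sketch with assumed exported names:

#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // names are assumptions; the actual module may export different ones
  m.def("moe_dispatch_cuda_forward", &moe_dispatch_cuda_forward);
  m.def("moe_dispatch_cuda_backward", &moe_dispatch_cuda_backward);
  m.def("moe_combine_cuda_forward", &moe_combine_cuda_forward);
  m.def("moe_combine_cuda_backward", &moe_combine_cuda_backward);
  m.def("cumsum_sub_one_in_dim0", &cumsum_sub_one_in_dim0);
}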