Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
6582aedc
".github/vscode:/vscode.git/clone" did not exist on "d41a9f12c6300224aa5671c35fc6cd3f786d54b5"
Commit
6582aedc
authored
Mar 31, 2022
by
binmakeswell
Browse files
fix format (#583)
parent
f08fc17f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
575 additions
and
591 deletions
+575
-591
colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu
colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu
+575
-591
No files found.
colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu
View file @
6582aedc
#include <torch/extension.h>
#include "block_reduce.h"
#include <cub/cub.cuh>
#include <cuda.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_fp16.h>
#include <cub/cub.cuh>
#include <torch/extension.h>
#include "block_reduce.h"
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_dpch_one_fwd
(
T
*
src_row
,
T
*
dst_row
,
const
int
cols
)
{
__device__
void
moe_dpch_one_fwd
(
T
*
src_row
,
T
*
dst_row
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_LOAD_VECTORIZE
>
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
BlockLoad
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_STORE_VECTORIZE
>
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
BlockStore
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
int
tps
=
threadIdx
.
x
*
pack_size
;
T
pack
[
pack_size
];
int
tps
=
threadIdx
.
x
*
pack_size
;
T
pack
[
pack_size
];
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
BlockLoad
(
ts_load
).
Load
(
src_row
+
idx
,
pack
);
BlockLoad
(
ts_load
).
Load
(
src_row
+
idx
,
pack
);
BlockStore
(
ts_store
).
Store
(
dst_row
+
idx
,
pack
);
BlockStore
(
ts_store
).
Store
(
dst_row
+
idx
,
pack
);
}
}
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_dpch_one_bwd
(
T
*
src_row
,
T
*
dst_row
,
const
int
cols
)
{
__device__
void
moe_dpch_one_bwd
(
T
*
src_row
,
T
*
dst_row
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_LOAD_VECTORIZE
>
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
BlockLoad
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_STORE_VECTORIZE
>
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
BlockStore
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
int
tps
=
threadIdx
.
x
*
pack_size
;
T
pack
[
pack_size
];
int
tps
=
threadIdx
.
x
*
pack_size
;
T
pack
[
pack_size
];
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
BlockLoad
(
ts_load
).
Load
(
dst_row
+
idx
,
pack
);
BlockLoad
(
ts_load
).
Load
(
dst_row
+
idx
,
pack
);
BlockStore
(
ts_store
).
Store
(
src_row
+
idx
,
pack
);
BlockStore
(
ts_store
).
Store
(
src_row
+
idx
,
pack
);
}
}
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_dpch_two_fwd
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
const
int
cols
)
{
__device__
void
moe_dpch_two_fwd
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_LOAD_VECTORIZE
>
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
BlockLoad
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_STORE_VECTORIZE
>
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
BlockStore
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
int
tps
=
threadIdx
.
x
*
pack_size
;
T
pack
[
pack_size
];
int
tps
=
threadIdx
.
x
*
pack_size
;
T
pack
[
pack_size
];
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
BlockLoad
(
ts_load
).
Load
(
src_row
+
idx
,
pack
);
BlockLoad
(
ts_load
).
Load
(
src_row
+
idx
,
pack
);
BlockStore
(
ts_store
).
Store
(
dst_row1
+
idx
,
pack
);
BlockStore
(
ts_store
).
Store
(
dst_row1
+
idx
,
pack
);
...
@@ -68,18 +72,19 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, const int
...
@@ -68,18 +72,19 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, const int
}
}
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_dpch_two_bwd
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
const
int
cols
)
{
__device__
void
moe_dpch_two_bwd
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_LOAD_VECTORIZE
>
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
BlockLoad
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_STORE_VECTORIZE
>
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
BlockStore
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
int
tps
=
threadIdx
.
x
*
pack_size
;
int
tps
=
threadIdx
.
x
*
pack_size
;
...
@@ -88,7 +93,7 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, const int
...
@@ -88,7 +93,7 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, const int
BlockLoad
(
ts_load
).
Load
(
dst_row1
+
idx
,
pack1
);
BlockLoad
(
ts_load
).
Load
(
dst_row1
+
idx
,
pack1
);
BlockLoad
(
ts_load
).
Load
(
dst_row2
+
idx
,
pack2
);
BlockLoad
(
ts_load
).
Load
(
dst_row2
+
idx
,
pack2
);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
pack1
[
i
]
+=
pack2
[
i
];
pack1
[
i
]
+=
pack2
[
i
];
}
}
...
@@ -97,27 +102,27 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, const int
...
@@ -97,27 +102,27 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, const int
}
}
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_cb_one_fwd
(
__device__
void
moe_cb_one_fwd
(
T
*
src_row
,
T
*
dst_row
,
const
T
weight
,
T
*
src_row
,
T
*
dst_row
,
const
int
cols
)
{
const
T
weight
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_LOAD_VECTORIZE
>
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
BlockLoad
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_STORE_VECTORIZE
>
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
BlockStore
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
int
tps
=
threadIdx
.
x
*
pack_size
;
T
pack
[
pack_size
];
int
tps
=
threadIdx
.
x
*
pack_size
;
T
pack
[
pack_size
];
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
BlockLoad
(
ts_load
).
Load
(
src_row
+
idx
,
pack
);
BlockLoad
(
ts_load
).
Load
(
src_row
+
idx
,
pack
);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
pack
[
i
]
*=
weight
;
pack
[
i
]
*=
weight
;
}
}
...
@@ -126,20 +131,19 @@ __device__ void moe_cb_one_fwd(
...
@@ -126,20 +131,19 @@ __device__ void moe_cb_one_fwd(
}
}
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_cb_one_bwd
(
__device__
void
moe_cb_one_bwd
(
T
*
src_row
,
T
*
dst_row
,
T
*
tks_row
,
T
*
src_row
,
T
*
dst_row
,
T
*
tks_row
,
T
*
weight_grad
,
T
*
weight_grad
,
const
T
weight
,
const
int
cols
)
{
const
T
weight
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_LOAD_VECTORIZE
>
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
BlockLoad
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_STORE_VECTORIZE
>
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
BlockStore
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
int
tps
=
threadIdx
.
x
*
pack_size
;
int
tps
=
threadIdx
.
x
*
pack_size
;
...
@@ -149,7 +153,7 @@ __device__ void moe_cb_one_bwd(
...
@@ -149,7 +153,7 @@ __device__ void moe_cb_one_bwd(
BlockLoad
(
ts_load
).
Load
(
dst_row
+
idx
,
grad
);
BlockLoad
(
ts_load
).
Load
(
dst_row
+
idx
,
grad
);
BlockLoad
(
ts_load
).
Load
(
tks_row
+
idx
,
tokens
);
BlockLoad
(
ts_load
).
Load
(
tks_row
+
idx
,
tokens
);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
thread_sum
+=
grad
[
i
]
*
tokens
[
i
];
thread_sum
+=
grad
[
i
]
*
tokens
[
i
];
grad
[
i
]
*=
weight
;
grad
[
i
]
*=
weight
;
...
@@ -164,20 +168,20 @@ __device__ void moe_cb_one_bwd(
...
@@ -164,20 +168,20 @@ __device__ void moe_cb_one_bwd(
*
weight_grad
=
static_cast
<
T
>
(
thread_sum
);
*
weight_grad
=
static_cast
<
T
>
(
thread_sum
);
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_cb_two_fwd
(
__device__
void
moe_cb_two_fwd
(
T
*
src_row1
,
T
*
src_row2
,
T
*
dst_row
,
T
*
src_row1
,
T
*
src_row2
,
T
*
dst_row
,
const
T
weight1
,
const
T
weight2
,
const
T
weight1
,
const
T
weight2
,
const
int
cols
)
{
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_LOAD_VECTORIZE
>
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
BlockLoad
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_STORE_VECTORIZE
>
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
BlockStore
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
int
tps
=
threadIdx
.
x
*
pack_size
;
int
tps
=
threadIdx
.
x
*
pack_size
;
...
@@ -186,7 +190,7 @@ __device__ void moe_cb_two_fwd(
...
@@ -186,7 +190,7 @@ __device__ void moe_cb_two_fwd(
BlockLoad
(
ts_load
).
Load
(
src_row1
+
idx
,
pack1
);
BlockLoad
(
ts_load
).
Load
(
src_row1
+
idx
,
pack1
);
BlockLoad
(
ts_load
).
Load
(
src_row2
+
idx
,
pack2
);
BlockLoad
(
ts_load
).
Load
(
src_row2
+
idx
,
pack2
);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
pack1
[
i
]
=
pack1
[
i
]
*
weight1
+
pack2
[
i
]
*
weight2
;
pack1
[
i
]
=
pack1
[
i
]
*
weight1
+
pack2
[
i
]
*
weight2
;
}
}
...
@@ -195,33 +199,33 @@ __device__ void moe_cb_two_fwd(
...
@@ -195,33 +199,33 @@ __device__ void moe_cb_two_fwd(
}
}
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_cb_two_bwd
(
__device__
void
moe_cb_two_bwd
(
T
*
src_row1
,
T
*
src_row2
,
T
*
dst_row
,
T
*
src
_row1
,
T
*
src
_row2
,
T
*
dst_row
,
T
*
tks
_row1
,
T
*
tks
_row2
,
T
*
weight_grad1
,
T
*
tks_row1
,
T
*
tks_row2
,
T
*
weight_grad
1
,
T
*
weight
_grad2
,
T
*
weight_grad
2
,
const
T
weight
1
,
const
T
weight1
,
const
T
weight2
,
const
int
cols
)
{
const
T
weight2
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_LOAD_VECTORIZE
>
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
BlockLoad
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_STORE_VECTORIZE
>
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
BlockStore
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
int
tps
=
threadIdx
.
x
*
pack_size
;
int
tps
=
threadIdx
.
x
*
pack_size
;
T
grad
[
pack_size
],
tokens1
[
pack_size
],
tokens2
[
pack_size
],
T
grad
[
pack_size
],
tokens1
[
pack_size
],
tokens2
[
pack_size
],
sgrad1
[
pack_size
],
sgrad1
[
pack_size
],
sgrad2
[
pack_size
];
sgrad2
[
pack_size
];
float
thread_sum
[
2
]
=
{
0
,
0
};
float
thread_sum
[
2
]
=
{
0
,
0
};
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
BlockLoad
(
ts_load
).
Load
(
dst_row
+
idx
,
grad
);
BlockLoad
(
ts_load
).
Load
(
dst_row
+
idx
,
grad
);
BlockLoad
(
ts_load
).
Load
(
tks_row1
+
idx
,
tokens1
);
BlockLoad
(
ts_load
).
Load
(
tks_row1
+
idx
,
tokens1
);
BlockLoad
(
ts_load
).
Load
(
tks_row2
+
idx
,
tokens2
);
BlockLoad
(
ts_load
).
Load
(
tks_row2
+
idx
,
tokens2
);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
thread_sum
[
0
]
+=
grad
[
i
]
*
tokens1
[
i
];
thread_sum
[
0
]
+=
grad
[
i
]
*
tokens1
[
i
];
thread_sum
[
1
]
+=
grad
[
i
]
*
tokens2
[
i
];
thread_sum
[
1
]
+=
grad
[
i
]
*
tokens2
[
i
];
...
@@ -239,142 +243,125 @@ __device__ void moe_cb_two_bwd(
...
@@ -239,142 +243,125 @@ __device__ void moe_cb_two_bwd(
*
weight_grad1
=
static_cast
<
T
>
(
thread_sum
[
0
]);
*
weight_grad1
=
static_cast
<
T
>
(
thread_sum
[
0
]);
else
if
(
threadIdx
.
x
==
1
)
else
if
(
threadIdx
.
x
==
1
)
*
weight_grad2
=
static_cast
<
T
>
(
thread_sum
[
1
]);
*
weight_grad2
=
static_cast
<
T
>
(
thread_sum
[
1
]);
}
}
// DISPATCH KERNELS --------------------------------
// DISPATCH KERNELS --------------------------------
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_dpch_fwd_selector
(
__device__
void
moe_dpch_fwd_selector
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
const
int
cols
,
const
int
cols
,
const
int
indicator1
,
const
int
indicator1
,
const
int
indicator2
)
{
const
int
indicator2
)
{
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
moe_dpch_two_fwd
<
T
,
block_size
,
pack_size
>
(
moe_dpch_two_fwd
<
T
,
block_size
,
pack_size
>
(
src_row
,
dst_row1
,
dst_row2
,
src_row
,
dst_row1
,
dst_row2
,
cols
);
cols
);
else
if
(
indicator1
!=
0
)
else
if
(
indicator1
!=
0
)
moe_dpch_one_fwd
<
T
,
block_size
,
pack_size
>
(
moe_dpch_one_fwd
<
T
,
block_size
,
pack_size
>
(
src_row
,
dst_row1
,
cols
);
src_row
,
dst_row1
,
cols
);
else
if
(
indicator2
!=
0
)
else
if
(
indicator2
!=
0
)
moe_dpch_one_fwd
<
T
,
block_size
,
pack_size
>
(
moe_dpch_one_fwd
<
T
,
block_size
,
pack_size
>
(
src_row
,
dst_row2
,
cols
);
src_row
,
dst_row2
,
cols
);
else
else
return
;
return
;
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_dpch_bwd_selector
(
__device__
void
moe_dpch_bwd_selector
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
const
int
cols
,
const
int
cols
,
const
int
indicator1
,
const
int
indicator1
,
const
int
indicator2
)
{
const
int
indicator2
)
{
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
moe_dpch_two_bwd
<
T
,
block_size
,
pack_size
>
(
moe_dpch_two_bwd
<
T
,
block_size
,
pack_size
>
(
src_row
,
dst_row1
,
dst_row2
,
src_row
,
dst_row1
,
dst_row2
,
cols
);
cols
);
else
if
(
indicator1
!=
0
)
else
if
(
indicator1
!=
0
)
moe_dpch_one_bwd
<
T
,
block_size
,
pack_size
>
(
moe_dpch_one_bwd
<
T
,
block_size
,
pack_size
>
(
src_row
,
dst_row1
,
cols
);
src_row
,
dst_row1
,
cols
);
else
if
(
indicator2
!=
0
)
else
if
(
indicator2
!=
0
)
moe_dpch_one_bwd
<
T
,
block_size
,
pack_size
>
(
moe_dpch_one_bwd
<
T
,
block_size
,
pack_size
>
(
src_row
,
dst_row2
,
cols
);
src_row
,
dst_row2
,
cols
);
else
else
return
;
return
;
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__global__
void
moe_dpch_fwd_kernel
(
__global__
void
moe_dpch_fwd_kernel
(
T
*
batch_tokens
,
T
*
expert_input
,
T
*
batch_tokens
,
T
*
expert_input
,
int
*
mask1
,
int
*
mask2
,
int
*
dest1
,
int
*
mask1
,
int
*
mask2
,
int
*
dest2
,
const
int
h
)
{
int
*
dest1
,
int
*
dest2
,
const
int
h
)
{
int
row
=
blockIdx
.
x
;
int
row
=
blockIdx
.
x
;
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
moe_dpch_fwd_selector
<
T
,
block_size
,
pack_size
>
(
moe_dpch_fwd_selector
<
T
,
block_size
,
pack_size
>
(
batch_tokens
+
(
row
*
h
),
batch_tokens
+
(
row
*
h
),
expert_input
+
(
dest1
[
row
]
*
h
),
expert_input
+
(
dest1
[
row
]
*
h
),
expert_input
+
(
dest2
[
row
]
*
h
),
expert_input
+
(
dest2
[
row
]
*
h
),
h
,
mask1
[
row
],
indicator2
);
h
,
mask1
[
row
],
indicator2
);
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__global__
void
moe_dpch_bwd_kernel
(
__global__
void
moe_dpch_bwd_kernel
(
T
*
tokens_grad
,
T
*
expert_grad
,
int
*
mask1
,
T
*
tokens_grad
,
T
*
expert_grad
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
int
*
mask1
,
int
*
mask2
,
const
int
h
)
{
int
*
dest1
,
int
*
dest2
,
const
int
h
)
{
int
row
=
blockIdx
.
x
;
int
row
=
blockIdx
.
x
;
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
moe_dpch_bwd_selector
<
T
,
block_size
,
pack_size
>
(
moe_dpch_bwd_selector
<
T
,
block_size
,
pack_size
>
(
tokens_grad
+
(
row
*
h
),
tokens_grad
+
(
row
*
h
),
expert_grad
+
(
dest1
[
row
]
*
h
),
expert_grad
+
(
dest1
[
row
]
*
h
),
expert_grad
+
(
dest2
[
row
]
*
h
),
expert_grad
+
(
dest2
[
row
]
*
h
),
h
,
mask1
[
row
],
indicator2
);
h
,
mask1
[
row
],
indicator2
);
}
}
// COMBINE KERNELS --------------------------------
// COMBINE KERNELS --------------------------------
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_cb_fwd_selector
(
__device__
void
moe_cb_fwd_selector
(
T
*
src_row1
,
T
*
src_row2
,
T
*
dst_row
,
T
*
src_row1
,
T
*
src_row2
,
T
*
dst_row
,
const
int
cols
,
const
int
cols
,
const
T
weight1
,
const
T
weight
1
,
const
T
weight2
,
const
T
weight
2
,
const
int
indicator1
,
const
int
indicator1
,
const
int
indicator2
)
{
const
int
indicator2
)
{
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
moe_cb_two_fwd
<
T
,
block_size
,
pack_size
>
(
moe_cb_two_fwd
<
T
,
block_size
,
pack_size
>
(
src_row1
,
src_row2
,
dst_row
,
src_row1
,
src_row2
,
dst_row
,
weight1
,
weight2
,
cols
);
weight1
,
weight2
,
cols
);
else
if
(
indicator1
!=
0
)
else
if
(
indicator1
!=
0
)
moe_cb_one_fwd
<
T
,
block_size
,
pack_size
>
(
moe_cb_one_fwd
<
T
,
block_size
,
pack_size
>
(
src_row1
,
dst_row
,
weight1
,
cols
);
src_row1
,
dst_row
,
weight1
,
cols
);
else
if
(
indicator2
!=
0
)
else
if
(
indicator2
!=
0
)
moe_cb_one_fwd
<
T
,
block_size
,
pack_size
>
(
moe_cb_one_fwd
<
T
,
block_size
,
pack_size
>
(
src_row2
,
dst_row
,
weight2
,
cols
);
src_row2
,
dst_row
,
weight2
,
cols
);
else
else
return
;
return
;
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_cb_bwd_selector
(
__device__
void
moe_cb_bwd_selector
(
T
*
src_row1
,
T
*
src_row2
,
T
*
dst_row
,
T
*
src_row1
,
T
*
src
_row
2
,
T
*
dst
_row
,
const
int
cols
,
const
int
cols
,
T
*
tks
_row
1
,
T
*
tks
_row
2
,
T
*
tks_row1
,
T
*
tks_row2
,
T
*
wt_grad1
,
T
*
wt_grad2
,
T
*
wt_grad1
,
T
*
wt_grad2
,
const
T
weight1
,
const
T
weight
1
,
const
T
weight2
,
const
T
weight
2
,
const
int
indicator1
,
const
int
indicator1
,
const
int
indicator2
)
{
const
int
indicator2
)
{
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
moe_cb_two_bwd
<
T
,
block_size
,
pack_size
>
(
moe_cb_two_bwd
<
T
,
block_size
,
pack_size
>
(
src_row1
,
src_row2
,
dst_row
,
src_row1
,
src_row2
,
dst_row
,
tks_row1
,
tks_row2
,
wt_grad1
,
tks_row1
,
tks_row2
,
wt_grad1
,
wt_grad2
,
wt_grad2
,
weight1
,
weight2
,
cols
);
weight1
,
weight2
,
cols
);
else
if
(
indicator1
!=
0
)
else
if
(
indicator1
!=
0
)
moe_cb_one_bwd
<
T
,
block_size
,
pack_size
>
(
moe_cb_one_bwd
<
T
,
block_size
,
pack_size
>
(
src_row1
,
dst_row
,
tks_row1
,
src_row1
,
dst_row
,
tks_row1
,
wt_grad1
,
weight1
,
cols
);
wt_grad1
,
weight1
,
cols
);
else
if
(
indicator2
!=
0
)
else
if
(
indicator2
!=
0
)
moe_cb_one_bwd
<
T
,
block_size
,
pack_size
>
(
moe_cb_one_bwd
<
T
,
block_size
,
pack_size
>
(
src_row2
,
dst_row
,
tks_row2
,
src_row2
,
dst_row
,
tks_row2
,
wt_grad2
,
weight2
,
cols
);
wt_grad2
,
weight2
,
cols
);
else
else
return
;
return
;
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__global__
void
moe_cb_fwd_kernel
(
T
*
expert_tokens
,
T
*
combine_tokens
,
__global__
void
moe_cb_fwd_kernel
(
T
*
logits
,
int
*
mask1
,
int
*
mask2
,
int
*
dest1
,
T
*
expert_tokens
,
T
*
combine_tokens
,
T
*
logits
,
int
*
dest2
,
const
int
e
,
const
int
c
,
int
*
mask1
,
int
*
mask2
,
const
int
h
)
{
int
*
dest1
,
int
*
dest2
,
const
int
e
,
const
int
c
,
const
int
h
)
{
int
row
=
blockIdx
.
x
,
eid1
=
dest1
[
row
]
/
c
,
eid2
=
dest2
[
row
]
/
c
;
int
row
=
blockIdx
.
x
,
eid1
=
dest1
[
row
]
/
c
,
eid2
=
dest2
[
row
]
/
c
;
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
T
*
row_log
=
logits
+
(
row
*
e
);
T
*
row_log
=
logits
+
(
row
*
e
);
moe_cb_fwd_selector
<
T
,
block_size
,
pack_size
>
(
moe_cb_fwd_selector
<
T
,
block_size
,
pack_size
>
(
expert_tokens
+
(
dest1
[
row
]
*
h
),
expert_tokens
+
(
dest2
[
row
]
*
h
),
expert_tokens
+
(
dest1
[
row
]
*
h
),
expert_tokens
+
(
dest2
[
row
]
*
h
),
combine_tokens
+
(
row
*
h
),
h
,
combine_tokens
+
(
row
*
h
),
h
,
row_log
[
eid1
],
row_log
[
eid2
],
mask1
[
row
],
row_log
[
eid1
],
row_log
[
eid2
],
indicator2
);
mask1
[
row
],
indicator2
);
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__global__
void
moe_cb_bwd_kernel
(
__global__
void
moe_cb_bwd_kernel
(
T
*
tokens_grad
,
T
*
expert_grad
,
T
*
tks
,
T
*
tokens_grad
,
T
*
expert_grad
,
T
*
tks
,
T
*
logits
,
T
*
logits_grad
,
int
*
mask1
,
T
*
logits
,
T
*
logits_grad
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
int
*
mask1
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
const
int
e
,
const
int
c
,
const
int
h
)
{
const
int
e
,
const
int
c
,
const
int
h
)
{
int
row
=
blockIdx
.
x
,
eid1
=
dest1
[
row
]
/
c
,
eid2
=
dest2
[
row
]
/
c
;
int
row
=
blockIdx
.
x
,
eid1
=
dest1
[
row
]
/
c
,
eid2
=
dest2
[
row
]
/
c
;
...
@@ -382,36 +369,33 @@ __global__ void moe_cb_bwd_kernel(
...
@@ -382,36 +369,33 @@ __global__ void moe_cb_bwd_kernel(
T
*
row_log
=
logits
+
(
row
*
e
),
*
row_grad
=
logits_grad
+
(
row
*
e
);
T
*
row_log
=
logits
+
(
row
*
e
),
*
row_grad
=
logits_grad
+
(
row
*
e
);
moe_cb_bwd_selector
<
T
,
block_size
,
pack_size
>
(
moe_cb_bwd_selector
<
T
,
block_size
,
pack_size
>
(
expert_grad
+
(
dest1
[
row
]
*
h
),
expert_grad
+
(
dest2
[
row
]
*
h
),
expert_grad
+
(
dest1
[
row
]
*
h
),
expert_grad
+
(
dest2
[
row
]
*
h
),
tokens_grad
+
(
row
*
h
),
h
,
tokens_grad
+
(
row
*
h
),
h
,
tks
+
(
dest1
[
row
]
*
h
),
tks
+
(
dest1
[
row
]
*
h
),
tks
+
(
dest2
[
row
]
*
h
),
tks
+
(
dest2
[
row
]
*
h
),
row_grad
+
eid1
,
row_grad
+
eid2
,
row_log
[
eid1
],
row_grad
+
eid1
,
row_grad
+
eid2
,
row_log
[
eid2
],
mask1
[
row
],
indicator2
);
row_log
[
eid1
],
row_log
[
eid2
],
mask1
[
row
],
indicator2
);
}
}
//CUMSUM KERNEL --------------------------------
//
CUMSUM KERNEL --------------------------------
template
<
int
block_size
,
int
pack_size
>
template
<
int
block_size
,
int
pack_size
>
__global__
void
cumsum_kernel
(
__global__
void
cumsum_kernel
(
int
*
inputs
,
int
*
outputs
,
const
int
s
,
int
*
inputs
,
int
*
outputs
,
const
int
e
)
{
const
int
s
,
const
int
e
)
{
assert
(
s
%
pack_size
==
0
);
assert
(
s
%
pack_size
==
0
);
constexpr
int
bpack_size
=
block_size
*
pack_size
;
constexpr
int
bpack_size
=
block_size
*
pack_size
;
int
tid
=
threadIdx
.
x
,
bid
=
blockIdx
.
x
,
int
tid
=
threadIdx
.
x
,
bid
=
blockIdx
.
x
,
tps
=
tid
*
pack_size
,
last_sum
=
-
1
;
tps
=
tid
*
pack_size
,
last_sum
=
-
1
;
__shared__
int
temp
[
block_size
+
1
]
;
__shared__
int
temp
[
block_size
+
1
];
int
pack
[
pack_size
];
int
pack
[
pack_size
];
for
(
int
idx
=
0
;
idx
<
s
;
idx
+=
bpack_size
)
{
for
(
int
idx
=
0
;
idx
<
s
;
idx
+=
bpack_size
)
{
int
offset
=
1
;
int
offset
=
1
;
if
(
idx
+
tps
<
s
)
{
if
(
idx
+
tps
<
s
)
{
temp
[
tid
]
=
inputs
[
tps
*
e
+
bid
];
temp
[
tid
]
=
inputs
[
tps
*
e
+
bid
];
#pragma unroll
#pragma unroll
for
(
int
i
=
1
;
i
<
pack_size
;
++
i
)
{
for
(
int
i
=
1
;
i
<
pack_size
;
++
i
)
{
pack
[
i
]
=
inputs
[(
tps
+
i
)
*
e
+
bid
];
pack
[
i
]
=
inputs
[(
tps
+
i
)
*
e
+
bid
];
}
}
#pragma unroll
#pragma unroll
for
(
int
i
=
1
;
i
<
pack_size
;
++
i
)
{
for
(
int
i
=
1
;
i
<
pack_size
;
++
i
)
{
temp
[
tid
]
+=
pack
[
i
];
temp
[
tid
]
+=
pack
[
i
];
}
}
...
@@ -435,8 +419,7 @@ __global__ void cumsum_kernel(
...
@@ -435,8 +419,7 @@ __global__ void cumsum_kernel(
offset
>>=
1
;
offset
>>=
1
;
__syncthreads
();
__syncthreads
();
if
(
tid
<
i
)
{
if
(
tid
<
i
)
{
int
j
=
offset
*
(
2
*
tid
+
1
)
-
1
,
int
j
=
offset
*
(
2
*
tid
+
1
)
-
1
,
k
=
j
+
offset
,
ts
=
temp
[
j
];
k
=
j
+
offset
,
ts
=
temp
[
j
];
temp
[
j
]
=
temp
[
k
];
temp
[
j
]
=
temp
[
k
];
temp
[
k
]
+=
ts
;
temp
[
k
]
+=
ts
;
}
}
...
@@ -449,7 +432,7 @@ __global__ void cumsum_kernel(
...
@@ -449,7 +432,7 @@ __global__ void cumsum_kernel(
if
(
idx
+
tps
<
s
)
{
if
(
idx
+
tps
<
s
)
{
temp
[
tid
+
1
]
+=
last_sum
;
temp
[
tid
+
1
]
+=
last_sum
;
#pragma unroll
#pragma unroll
for
(
int
i
=
pack_size
-
1
;
i
>
0
;
--
i
)
{
for
(
int
i
=
pack_size
-
1
;
i
>
0
;
--
i
)
{
outputs
[(
tps
+
i
)
*
e
+
bid
]
=
temp
[
tid
+
1
];
outputs
[(
tps
+
i
)
*
e
+
bid
]
=
temp
[
tid
+
1
];
temp
[
tid
+
1
]
-=
pack
[
i
];
temp
[
tid
+
1
]
-=
pack
[
i
];
...
@@ -464,95 +447,103 @@ __global__ void cumsum_kernel(
...
@@ -464,95 +447,103 @@ __global__ void cumsum_kernel(
}
}
}
}
//LAUNCH FUNCTIONS --------------------------------
//
LAUNCH FUNCTIONS --------------------------------
template
<
typename
T
>
template
<
typename
T
>
void
moe_dpch_fwd_launch
(
void
moe_dpch_fwd_launch
(
T
*
batch_tokens
,
T
*
expert_input
,
int
*
mask1
,
T
*
batch_tokens
,
T
*
expert_input
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
const
int
s
,
int
*
mask1
,
int
*
mask2
,
const
int
h
)
{
int
*
dest1
,
int
*
dest2
,
const
int
s
,
const
int
h
)
{
if
(
h
<
256
)
if
(
h
<
256
)
moe_dpch_fwd_kernel
<
T
,
32
,
4
><<<
s
,
32
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_fwd_kernel
<
T
,
32
,
4
>
<<<
s
,
32
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
else
if
(
h
<
512
)
else
if
(
h
<
512
)
moe_dpch_fwd_kernel
<
T
,
32
,
8
><<<
s
,
32
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_fwd_kernel
<
T
,
32
,
8
>
<<<
s
,
32
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
else
if
(
h
<
1024
)
else
if
(
h
<
1024
)
moe_dpch_fwd_kernel
<
T
,
32
,
16
><<<
s
,
32
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_fwd_kernel
<
T
,
32
,
16
>
<<<
s
,
32
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
else
if
(
h
<
2048
)
else
if
(
h
<
2048
)
moe_dpch_fwd_kernel
<
T
,
64
,
16
><<<
s
,
64
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_fwd_kernel
<
T
,
64
,
16
>
<<<
s
,
64
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
else
else
moe_dpch_fwd_kernel
<
T
,
128
,
16
><<<
s
,
128
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_fwd_kernel
<
T
,
128
,
16
>
<<<
s
,
128
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
}
}
template
<
typename
T
>
template
<
typename
T
>
void
moe_dpch_bwd_launch
(
void
moe_dpch_bwd_launch
(
T
*
tokens_grad
,
T
*
expert_grad
,
int
*
mask1
,
int
*
mask2
,
T
*
tokens_grad
,
T
*
expert_grad
,
int
*
dest1
,
int
*
dest2
,
const
int
s
,
const
int
h
)
{
int
*
mask1
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
const
int
s
,
const
int
h
)
{
if
(
h
<
256
)
if
(
h
<
256
)
moe_dpch_bwd_kernel
<
T
,
32
,
4
><<<
s
,
32
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_bwd_kernel
<
T
,
32
,
4
>
<<<
s
,
32
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
else
if
(
h
<
512
)
else
if
(
h
<
512
)
moe_dpch_bwd_kernel
<
T
,
32
,
8
><<<
s
,
32
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_bwd_kernel
<
T
,
32
,
8
>
<<<
s
,
32
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
else
if
(
h
<
1024
)
else
if
(
h
<
1024
)
moe_dpch_bwd_kernel
<
T
,
32
,
16
><<<
s
,
32
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_bwd_kernel
<
T
,
32
,
16
>
<<<
s
,
32
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
else
if
(
h
<
2048
)
else
if
(
h
<
2048
)
moe_dpch_bwd_kernel
<
T
,
64
,
16
><<<
s
,
64
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_bwd_kernel
<
T
,
64
,
16
>
<<<
s
,
64
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
else
else
moe_dpch_bwd_kernel
<
T
,
128
,
16
><<<
s
,
128
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_bwd_kernel
<
T
,
128
,
16
>
<<<
s
,
128
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
}
}
// Host-side launcher for the MoE combine forward kernel.
//
// Picks a (block_size, pack_size) instantiation according to the hidden
// size h and launches one thread block per output token row (grid = s).
//
//   expert_tokens  : expert output rows to be combined
//   combine_tokens : combined output buffer
//   logits         : gating weights used by the kernel
//   mask1, mask2   : routing masks (mask2 may be nullptr — presumably top-1
//                    routing; confirm against the kernel definition)
//   dest1, dest2   : source row indices inside the expert buffer
//   s, e, c, h     : tokens, experts, capacity, hidden size
template <typename T>
void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
                       int *mask1, int *mask2, int *dest1, int *dest2,
                       const int s, const int e, const int c, const int h) {
// Local helper macro: branches differ only in <block, pack> and blockDim.
#define LAUNCH_CB_FWD_(BLOCK_, PACK_)                                         \
  moe_cb_fwd_kernel<T, BLOCK_, PACK_><<<s, BLOCK_>>>(                         \
      expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e,   \
      c, h)

  if (h < 256)
    LAUNCH_CB_FWD_(32, 4);
  else if (h < 512)
    LAUNCH_CB_FWD_(32, 8);
  else if (h < 1024)
    LAUNCH_CB_FWD_(32, 16);
  else if (h < 2048)
    LAUNCH_CB_FWD_(64, 16);
  else
    LAUNCH_CB_FWD_(128, 16);

#undef LAUNCH_CB_FWD_
}
// Host-side launcher for the MoE combine backward kernel.
//
// Only two configurations are active: <32, 4> for h < 256 and <64, 4>
// otherwise. The wider <128, 4> / <256, 4> branches are intentionally kept
// disabled (see the commented-out launches below).
//
//   tokens_grad : incoming gradient w.r.t. the combined tokens
//   expert_grad : gradient buffer written for the expert tokens
//   tks         : forward-pass expert tokens, re-read by the kernel
//   logits      : gating weights
//   logits_grad : gradient buffer written for the gating weights
//   mask1/mask2, dest1/dest2 : routing masks and destination indices
//                  (mask2 may be nullptr — presumably top-1 routing)
//   s, e, c, h  : tokens, experts, capacity, hidden size
template <typename T>
void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
                       T *logits_grad, int *mask1, int *mask2, int *dest1,
                       int *dest2, const int s, const int e, const int c,
                       const int h) {
// Local helper macro: branches differ only in <block, pack> and blockDim.
#define LAUNCH_CB_BWD_(BLOCK_, PACK_)                                         \
  moe_cb_bwd_kernel<T, BLOCK_, PACK_><<<s, BLOCK_>>>(                         \
      tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,       \
      dest1, dest2, e, c, h)

  if (h < 256)
    LAUNCH_CB_BWD_(32, 4);
  else
    // if (h < 512)
    LAUNCH_CB_BWD_(64, 4);
  // else if (h < 1024)
  //   LAUNCH_CB_BWD_(128, 4);
  // else
  //   LAUNCH_CB_BWD_(256, 4);

#undef LAUNCH_CB_BWD_
}
void
cumsum_launch
(
void
cumsum_launch
(
int
*
inputs
,
int
*
outputs
,
const
int
s
,
const
int
e
)
{
int
*
inputs
,
int
*
outputs
,
const
int
s
,
const
int
e
)
{
if
(
s
<=
256
)
if
(
s
<=
256
)
cumsum_kernel
<
256
,
1
><<<
e
,
256
>>>
(
inputs
,
outputs
,
s
,
e
);
cumsum_kernel
<
256
,
1
><<<
e
,
256
>>>
(
inputs
,
outputs
,
s
,
e
);
...
@@ -569,32 +560,29 @@ void cumsum_launch(
...
@@ -569,32 +560,29 @@ void cumsum_launch(
// API FUNCTIONS --------------------------------
// API FUNCTIONS --------------------------------
// Expands __VA_ARGS__ once per supported dtype with the alias `scalar_t`
// bound to the matching C++ type. Supports float32 (-> float) and float16
// (-> at::Half); any other ATen scalar type raises through AT_ERROR with
// NAME stringized into the message.
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...)                        \
  switch (TYPE) {                                                       \
    case at::ScalarType::Float: {                                       \
      using scalar_t = float;                                           \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    case at::ScalarType::Half: {                                        \
      using scalar_t = at::Half;                                        \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    default:                                                            \
      AT_ERROR(#NAME, " not implemented yet for specific data type.");  \
  }
torch
::
Tensor
moe_dispatch_cuda_forward
(
torch
::
Tensor
moe_dispatch_cuda_forward
(
int
s
,
int
ec
,
int
h
,
int
s
,
int
ec
,
int
h
,
torch
::
Tensor
batch_tokens
,
torch
::
Tensor
batch_tokens
,
torch
::
Tensor
mask
,
torch
::
Tensor
mask
,
torch
::
Tensor
dest_idx
)
{
torch
::
Tensor
dest_idx
)
{
assert
(
h
%
16
==
0
);
assert
(
h
%
16
==
0
);
auto
res
=
torch
::
zeros
({
ec
,
h
},
auto
res
=
torch
::
zeros
(
{
ec
,
h
},
torch
::
dtype
(
batch_tokens
.
dtype
()).
device
(
batch_tokens
.
device
()));
torch
::
dtype
(
batch_tokens
.
dtype
()).
device
(
batch_tokens
.
device
()));
auto
k
=
mask
.
size
(
0
);
auto
k
=
mask
.
size
(
0
);
...
@@ -603,22 +591,20 @@ torch::Tensor moe_dispatch_cuda_forward(
...
@@ -603,22 +591,20 @@ torch::Tensor moe_dispatch_cuda_forward(
moe_dpch_fwd_launch
<
scalar_t
>
(
moe_dpch_fwd_launch
<
scalar_t
>
(
batch_tokens
.
data
<
scalar_t
>
(),
res
.
data
<
scalar_t
>
(),
batch_tokens
.
data
<
scalar_t
>
(),
res
.
data
<
scalar_t
>
(),
mask
[
0
].
data
<
int
>
(),
k
==
1
?
nullptr
:
mask
[
1
].
data
<
int
>
(),
mask
[
0
].
data
<
int
>
(),
k
==
1
?
nullptr
:
mask
[
1
].
data
<
int
>
(),
dest_idx
[
0
].
data
<
int
>
(),
k
==
1
?
dest_idx
[
0
].
data
<
int
>
()
:
dest_idx
[
1
].
data
<
int
>
(),
dest_idx
[
0
].
data
<
int
>
(),
s
,
h
)
k
==
1
?
dest_idx
[
0
].
data
<
int
>
()
:
dest_idx
[
1
].
data
<
int
>
(),
s
,
h
));
);
return
res
;
return
res
;
}
}
torch
::
Tensor
moe_dispatch_cuda_backward
(
torch
::
Tensor
moe_dispatch_cuda_backward
(
int
s
,
int
ec
,
int
h
,
int
s
,
int
ec
,
int
h
,
torch
::
Tensor
expert_grad
,
torch
::
Tensor
expert_grad
,
torch
::
Tensor
mask
,
torch
::
Tensor
mask
,
torch
::
Tensor
dest_idx
)
{
torch
::
Tensor
dest_idx
)
{
assert
(
h
%
16
==
0
);
assert
(
h
%
16
==
0
);
auto
res
=
torch
::
zeros
(
{
s
,
h
},
auto
res
=
torch
::
zeros
(
torch
::
dtype
(
expert_grad
.
dtype
()).
device
(
expert_grad
.
device
()));
{
s
,
h
},
torch
::
dtype
(
expert_grad
.
dtype
()).
device
(
expert_grad
.
device
()));
auto
k
=
mask
.
size
(
0
);
auto
k
=
mask
.
size
(
0
);
DISPATCH_FLOAT_AND_HALF
(
DISPATCH_FLOAT_AND_HALF
(
...
@@ -626,65 +612,62 @@ torch::Tensor moe_dispatch_cuda_backward(
...
@@ -626,65 +612,62 @@ torch::Tensor moe_dispatch_cuda_backward(
moe_dpch_bwd_launch
<
scalar_t
>
(
moe_dpch_bwd_launch
<
scalar_t
>
(
res
.
data
<
scalar_t
>
(),
expert_grad
.
data
<
scalar_t
>
(),
res
.
data
<
scalar_t
>
(),
expert_grad
.
data
<
scalar_t
>
(),
mask
[
0
].
data
<
int
>
(),
k
==
1
?
nullptr
:
mask
[
1
].
data
<
int
>
(),
mask
[
0
].
data
<
int
>
(),
k
==
1
?
nullptr
:
mask
[
1
].
data
<
int
>
(),
dest_idx
[
0
].
data
<
int
>
(),
k
==
1
?
dest_idx
[
0
].
data
<
int
>
()
:
dest_idx
[
1
].
data
<
int
>
(),
dest_idx
[
0
].
data
<
int
>
(),
s
,
h
)
k
==
1
?
dest_idx
[
0
].
data
<
int
>
()
:
dest_idx
[
1
].
data
<
int
>
(),
s
,
h
));
);
return
res
;
return
res
;
}
}
// Combines per-expert token rows back into per-token outputs, weighted by
// the gating logits.
//
//   s, e, c, h    : tokens, experts, capacity, hidden size
//   expert_tokens : expert output rows
//   logits        : gating weights (must share expert_tokens' dtype)
//   mask          : routing masks, stacked along dim 0 (k = mask.size(0);
//                   k == 1 presumably means top-1 routing — confirm caller)
//   dest_idx      : destination indices, stacked along dim 0 like mask
//
// Returns a freshly allocated [s, h] tensor on expert_tokens' device.
// Raises (TORCH_CHECK) if h is not a multiple of 16 or dtypes mismatch.
torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
                                       torch::Tensor expert_tokens,
                                       torch::Tensor logits,
                                       torch::Tensor mask,
                                       torch::Tensor dest_idx) {
  // TORCH_CHECK instead of assert: assert() is compiled out under NDEBUG,
  // which would let invalid inputs reach the kernel silently.
  TORCH_CHECK(h % 16 == 0, "moe combine forward: h must be a multiple of 16");
  TORCH_CHECK(expert_tokens.dtype() == logits.dtype(),
              "moe combine forward: expert_tokens/logits dtype mismatch");

  auto res = torch::zeros(
      {s, h},
      torch::dtype(expert_tokens.dtype()).device(expert_tokens.device()));
  auto k = mask.size(0);

  // data_ptr<T>() replaces the deprecated Tensor::data<T>() accessor.
  // With k == 1 the second mask is absent (nullptr) and dest_idx[0] is
  // reused for both destination slots.
  DISPATCH_FLOAT_AND_HALF(
      expert_tokens.scalar_type(), "moe combine forward",
      moe_cb_fwd_launch<scalar_t>(
          expert_tokens.data_ptr<scalar_t>(), res.data_ptr<scalar_t>(),
          logits.data_ptr<scalar_t>(), mask[0].data_ptr<int>(),
          k == 1 ? nullptr : mask[1].data_ptr<int>(),
          dest_idx[0].data_ptr<int>(),
          k == 1 ? dest_idx[0].data_ptr<int>() : dest_idx[1].data_ptr<int>(),
          s, e, c, h));

  return res;
}
// Backward pass of the MoE combine: produces the gradient w.r.t. the
// expert token rows and the gradient w.r.t. the gating logits.
//
//   s, e, c, h    : tokens, experts, capacity, hidden size
//   tokens_grad   : gradient w.r.t. the combined [s, h] output
//   expert_tokens : forward-pass expert rows (re-read for the logits grad)
//   logits        : gating weights used in the forward pass
//   mask, dest_idx: routing masks / destination indices stacked along dim 0
//                   (k = mask.size(0); k == 1 presumably top-1 routing)
//
// Returns {egrad [e*c, h], wgrad [s, e]} on the inputs' devices.
// Raises (TORCH_CHECK) if h is not a multiple of 16 or dtypes mismatch.
std::vector<torch::Tensor> moe_combine_cuda_backward(
    int s, int e, int c, int h, torch::Tensor tokens_grad,
    torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
    torch::Tensor dest_idx) {
  // TORCH_CHECK instead of assert: assert() is compiled out under NDEBUG,
  // which would let invalid inputs reach the kernel silently.
  TORCH_CHECK(h % 16 == 0, "moe combine backward: h must be a multiple of 16");
  TORCH_CHECK(tokens_grad.dtype() == expert_tokens.dtype(),
              "moe combine backward: tokens_grad/expert_tokens dtype mismatch");
  TORCH_CHECK(expert_tokens.dtype() == logits.dtype(),
              "moe combine backward: expert_tokens/logits dtype mismatch");

  auto egrad = torch::zeros(
           {e * c, h},
           torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())),
       wgrad = torch::zeros(
           {s, e}, torch::dtype(logits.dtype()).device(logits.device()));
  auto k = mask.size(0);

  // data_ptr<T>() replaces the deprecated Tensor::data<T>() accessor.
  // With k == 1 the second mask is absent (nullptr) and dest_idx[0] is
  // reused for both destination slots.
  DISPATCH_FLOAT_AND_HALF(
      tokens_grad.scalar_type(), "moe combine backward",
      moe_cb_bwd_launch<scalar_t>(
          tokens_grad.data_ptr<scalar_t>(), egrad.data_ptr<scalar_t>(),
          expert_tokens.data_ptr<scalar_t>(), logits.data_ptr<scalar_t>(),
          wgrad.data_ptr<scalar_t>(), mask[0].data_ptr<int>(),
          k == 1 ? nullptr : mask[1].data_ptr<int>(),
          dest_idx[0].data_ptr<int>(),
          k == 1 ? dest_idx[0].data_ptr<int>() : dest_idx[1].data_ptr<int>(),
          s, e, c, h));

  return {egrad, wgrad};
}
...
@@ -695,7 +678,8 @@ torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
...
@@ -695,7 +678,8 @@ torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
assert
(
mask
.
dtype
()
==
torch
::
kInt32
);
assert
(
mask
.
dtype
()
==
torch
::
kInt32
);
const
int
s
=
mask
.
size
(
0
),
e
=
mask
.
size
(
1
);
const
int
s
=
mask
.
size
(
0
),
e
=
mask
.
size
(
1
);
auto
res
=
torch
::
empty
({
s
,
e
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
mask
.
device
()));
auto
res
=
torch
::
empty
({
s
,
e
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
mask
.
device
()));
cumsum_launch
(
mask
.
data
<
int
>
(),
res
.
data
<
int
>
(),
s
,
e
);
cumsum_launch
(
mask
.
data
<
int
>
(),
res
.
data
<
int
>
(),
s
,
e
);
return
res
;
return
res
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment