OpenDAS / ColossalAI · Commits

Commit 16cc8e6a (unverified)
[builder] MOE builder (#2277)

Authored Jan 03, 2023 by Jiarui Fang; committed by GitHub on Jan 03, 2023.
Parent: 26e171af
Showing 6 changed files with 60 additions and 20 deletions (+60, -20):

- colossalai/kernel/__init__.py (+14, -2)
- colossalai/kernel/op_builder/__init__.py (+2, -1)
- colossalai/kernel/op_builder/builder.py (+2, -2)
- colossalai/kernel/op_builder/moe.py (+33, -0)
- colossalai/nn/layer/moe/_operation.py (+6, -12)
- setup.py (+3, -3)
colossalai/kernel/__init__.py

```diff
@@ -24,7 +24,19 @@ except ImportError:
     from colossalai.kernel.op_builder import ScaledSoftmaxBuilder
     scaled_upper_triang_masked_softmax = ScaledSoftmaxBuilder().load()
 
+try:
+    from colossalai._C import moe
+except ImportError:
+    from colossalai.kernel.op_builder import MOEBuilder
+    moe = MOEBuilder().load()
+
 __all__ = [
-    "fused_optim", "cpu_optim", "multihead_attention", "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention",
-    "scaled_upper_triang_masked_softmax"
+    "fused_optim",
+    "cpu_optim",
+    "multihead_attention",
+    "moe",
+    "LayerNorm",
+    "FusedScaleMaskSoftmax",
+    "MultiHeadAttention",
+    "scaled_upper_triang_masked_softmax",
 ]
```
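For context, a minimal sketch of how the lazily built kernel is consumed after this hunk; the tensor shape, dtype, and device are illustrative assumptions, and `cumsum_sub_one` is simply the kernel that `moe_cumsum` in `colossalai/nn/layer/moe/_operation.py` dispatches to below.

```python
# Minimal sketch, assuming a prebuilt colossalai._C.moe or a JIT build via MOEBuilder.
import torch

from colossalai.kernel import moe    # resolves to colossalai._C.moe, or MOEBuilder().load()

# Illustrative input only; the real callers pass router masks from the MoE layer.
mask = torch.ones(16, 4, dtype=torch.int32, device='cuda')
positions = moe.cumsum_sub_one(mask)
```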
colossalai/kernel/op_builder/__init__.py

```diff
 from .cpu_adam import CPUAdamBuilder
 from .fused_optim import FusedOptimBuilder
+from .moe import MOEBuilder
 from .multi_head_attn import MultiHeadAttnBuilder
 from .scaled_upper_triang_masked_softmax import ScaledSoftmaxBuilder
 
-__all__ = ['CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledSoftmaxBuilder']
+__all__ = ['CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledSoftmaxBuilder', 'MOEBuilder']
```
colossalai/kernel/op_builder/builder.py

```diff
 import os
 import re
-import sys
 from pathlib import Path
+from typing import List
 
 import torch
 
 
-def get_cuda_cc_flag():
+def get_cuda_cc_flag() -> List:
     """get_cuda_cc_flag
 
     cc flag for your GPU arch
```
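The new `-> List` annotation documents that the helper returns a list of nvcc architecture flags. For readers unfamiliar with the pattern, a generic compute-capability flag helper usually looks like the sketch below; this is an illustration only, not the `get_cuda_cc_flag()` defined in this repository.

```python
# Illustrative sketch only -- not the repository's get_cuda_cc_flag() implementation.
from typing import List

import torch


def example_cuda_cc_flag() -> List:
    """Return an nvcc -gencode flag matching the compute capability of GPU 0."""
    major, minor = torch.cuda.get_device_capability(0)
    cc = f"{major}{minor}"
    return [f"-gencode=arch=compute_{cc},code=sm_{cc}"]
```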
colossalai/kernel/op_builder/moe.py (new file, 0 → 100644)

```python
import os

from .builder import Builder, get_cuda_cc_flag


class MOEBuilder(Builder):

    def __init__(self):
        self.base_dir = "cuda_native/csrc"
        self.name = 'moe'
        super().__init__()

    def include_dirs(self):
        ret = []
        ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_home_include()]
        ret.append(os.path.join(self.base_dir, "kernels", "include"))
        return [self.colossalai_src_path(path) for path in ret]

    def sources_files(self):
        ret = [os.path.join(self.base_dir, fname) for fname in ['moe_cuda.cpp', 'moe_cuda_kernel.cu']]
        return [self.colossalai_src_path(path) for path in ret]

    def cxx_flags(self):
        return ['-O3', '-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']

    def nvcc_flags(self):
        extra_cuda_flags = [
            '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr',
            '--expt-extended-lambda'
        ]
        extra_cuda_flags.extend(get_cuda_cc_flag())
        ret = ['-O3', '--use_fast_math'] + extra_cuda_flags
        return ret
```
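A minimal sketch of how this builder is exercised by the other files in this commit; both calls come from the hunks in `colossalai/kernel/__init__.py` above and `setup.py` below, while the surrounding variable names are illustrative.

```python
# Minimal sketch based on the two call sites introduced by this commit.
from colossalai.kernel.op_builder import MOEBuilder

# JIT path (colossalai/kernel/__init__.py): compile the CUDA sources on first
# use and import the resulting module.
moe = MOEBuilder().load()

# Ahead-of-time path (setup.py): emit a setuptools extension that builds
# colossalai._C.moe when CUDA extensions are enabled at install time.
ext_module = MOEBuilder().builder('colossalai._C.moe')
```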
colossalai/nn/layer/moe/_operation.py

```diff
@@ -6,12 +6,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 
 COL_MOE_KERNEL_FLAG = False
-try:
-    import colossalai._C.moe
-    COL_MOE_KERNEL_FLAG = True
-except ImportError:
-    print("If you want to activate cuda mode for MoE, please install with cuda_ext!")
+from colossalai.kernel import moe
 
 
 class AllGather(torch.autograd.Function):
@@ -90,7 +85,7 @@ class MoeDispatch(torch.autograd.Function):
         s = tokens.size(0)
         h = tokens.size(1)
 
-        expert_input = colossalai._C.moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
+        expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
 
         ctx.save_for_backward(mask, dest_idx)
         ctx.s = s
@@ -102,7 +97,7 @@ class MoeDispatch(torch.autograd.Function):
     @staticmethod
     def backward(ctx, output_grad):
         mask, dest_idx = ctx.saved_tensors
-        d_tokens = colossalai._C.moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
+        d_tokens = moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
         return d_tokens, None, None, None
@@ -119,7 +114,7 @@ class MoeCombine(torch.autograd.Function):
         fp16_flag = (expert_tokens.dtype == torch.float16)
         cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
-        ctokens = colossalai._C.moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
+        ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
         output = ctokens.to(torch.float16) if fp16_flag else ctokens
 
         ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
@@ -138,8 +133,7 @@ class MoeCombine(torch.autograd.Function):
         cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
             else tokens_grad
         cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
-        d_expert, d_logits = colossalai._C.moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits,
-                                                                mask, dest_idx)
+        d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx)
         d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
 
         return d_expert, d_logits, None, None, None
@@ -149,6 +143,6 @@ def moe_cumsum(inputs: Tensor):
     dim0 = inputs.size(0)
     flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
     if flag and COL_MOE_KERNEL_FLAG:
-        return colossalai._C.moe.cumsum_sub_one(inputs)
+        return moe.cumsum_sub_one(inputs)
     else:
         return torch.cumsum(inputs, dim=0) - 1
```
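For reference, the pure-PyTorch fallback branch kept in `moe_cumsum` computes an inclusive cumulative sum along dim 0 shifted down by one. A quick check with an illustrative tensor:

```python
# The fallback branch of moe_cumsum: torch.cumsum(inputs, dim=0) - 1.
# The example tensor is illustrative only.
import torch

mask = torch.tensor([[1, 0],
                     [0, 1],
                     [1, 1]])
print(torch.cumsum(mask, dim=0) - 1)
# tensor([[ 0, -1],
#         [ 0,  0],
#         [ 1,  1]])
```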
setup.py

```diff
 import os
 import re
 
-from setuptools import Extension, find_packages, setup
+from setuptools import find_packages, setup
 
 from colossalai.kernel.op_builder.utils import get_cuda_bare_metal_version
@@ -161,8 +161,8 @@ if build_cuda_ext:
     ext_modules.append(
         cuda_ext_helper('colossalai._C.scaled_masked_softmax',
                         ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'], extra_cuda_flags + cc_flag))
 
-    ext_modules.append(
-        cuda_ext_helper('colossalai._C.moe', ['moe_cuda.cpp', 'moe_cuda_kernel.cu'], extra_cuda_flags + cc_flag))
+    from colossalai.kernel.op_builder import MOEBuilder
+    ext_modules.append(MOEBuilder().builder('colossalai._C.moe'))
 
     extra_cuda_flags = ['-maxrregcount=50']
```
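For completeness, a hedged sketch of how an op-builder-produced extension plugs into setuptools. Only the `MOEBuilder().builder('colossalai._C.moe')` call comes from this commit; the package name and the `BuildExtension` cmdclass wiring are assumptions based on typical torch CUDA-extension setups, not shown in this diff.

```python
# Hedged sketch: wiring an op-builder extension into setup().
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension

from colossalai.kernel.op_builder import MOEBuilder

ext_modules = [MOEBuilder().builder('colossalai._C.moe')]

setup(
    name='example-pkg',                      # illustrative, not this project's name
    packages=find_packages(),
    ext_modules=ext_modules,
    cmdclass={'build_ext': BuildExtension},  # torch's build_ext for CUDA sources
)
```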