OpenDAS / ColossalAI

Commit db4cbdc7 (unverified), authored Dec 30, 2022 by Jiarui Fang, committed by GitHub on Dec 30, 2022

[builder] builder for scaled_upper_triang_masked_softmax (#2234)

parent 31fe8423
Showing 6 changed files with 53 additions and 18 deletions (+53 −18)
- colossalai/kernel/__init__.py (+8 −1)
- colossalai/kernel/cuda_native/scaled_softmax.py (+4 −11)
- colossalai/kernel/op_builder/__init__.py (+2 −1)
- colossalai/kernel/op_builder/scaled_upper_triang_masked_softmax.py (+36 −0)
- examples/language/gpt/train_gpt_demo.py (+1 −1)
- setup.py (+2 −4)
colossalai/kernel/__init__.py

```diff
@@ -18,6 +18,13 @@ except ImportError:
     from colossalai.kernel.op_builder import MultiHeadAttnBuilder
     multihead_attention = MultiHeadAttnBuilder().load()
 
+try:
+    from colossalai._C import scaled_upper_triang_masked_softmax
+except ImportError:
+    from colossalai.kernel.op_builder import ScaledSoftmaxBuilder
+    scaled_upper_triang_masked_softmax = ScaledSoftmaxBuilder().load()
+
 __all__ = [
-    "fused_optim", "cpu_optim", "multihead_attention", "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"
+    "fused_optim", "cpu_optim", "multihead_attention", "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention",
+    "scaled_upper_triang_masked_softmax"
 ]
```
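This hunk mirrors the MultiHeadAttnBuilder block directly above it: import the precompiled colossalai._C binary when the installed wheel ships one, and fall back to JIT-compiling the kernel through the op builder otherwise. A generic sketch of the pattern, for illustration only (`load_prebuilt_or_jit` and its parameters are hypothetical helpers, not ColossalAI APIs):

```python
import importlib


def load_prebuilt_or_jit(prebuilt_module: str, builder_cls):
    """Return a compiled extension module: prebuilt if available, else JIT-built."""
    try:
        # Fast path: the wheel shipped a compiled binary, e.g.
        # 'colossalai._C.scaled_upper_triang_masked_softmax'.
        return importlib.import_module(prebuilt_module)
    except ImportError:
        # Slow path: compile from source on first use; the extension loader
        # caches the build, so later imports are cheap.
        return builder_cls().load()
```

The practical effect is that installation no longer has to compile this kernel up front; a machine without an nvcc toolchain can still import colossalai and pays the build cost only if and when the kernel is actually used.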
colossalai/kernel/cuda_native/scaled_softmax.py

```diff
@@ -23,27 +23,20 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, inputs, scale):
-        try:
-            import colossalai._C.scaled_upper_triang_masked_softmax
-        except ImportError:
-            raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
+        from colossalai.kernel import scaled_upper_triang_masked_softmax
 
         scale_t = torch.tensor([scale])
-        softmax_results = colossalai._C.scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
+        softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
 
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results
 
     @staticmethod
     def backward(ctx, output_grads):
-        try:
-            import colossalai._C.scaled_upper_triang_masked_softmax
-        except ImportError:
-            raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
+        from colossalai.kernel import scaled_upper_triang_masked_softmax
 
         softmax_results, scale_t = ctx.saved_tensors
-        input_grads = colossalai._C.scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results,
-                                                                                scale_t[0])
+        input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
 
         return input_grads, None
```
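For context, a hedged usage sketch of the wrapper above: callers go through the standard `torch.autograd.Function.apply` protocol, so the fused forward and the hand-written backward both participate in autograd. The tensor shape and dtype are illustrative assumptions (the fused kernel targets half-precision causal-attention scores); only the class name and the `forward(ctx, inputs, scale)` signature come from the diff.

```python
import torch

from colossalai.kernel.cuda_native.scaled_softmax import ScaledUpperTriangMaskedSoftmax

# Hypothetical attention scores: (batches, seq_len, seq_len), fp16, on GPU.
scores = torch.randn(8, 128, 128, dtype=torch.half, device='cuda', requires_grad=True)

# Fused scale + upper-triangular (causal) mask + softmax in one kernel launch.
probs = ScaledUpperTriangMaskedSoftmax.apply(scores, 0.125)

# backward() above supplies input_grads, so autograd works end to end.
probs.sum().backward()
```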
colossalai/kernel/op_builder/__init__.py

```diff
 from .cpu_adam import CPUAdamBuilder
 from .fused_optim import FusedOptimBuilder
 from .multi_head_attn import MultiHeadAttnBuilder
+from .scaled_upper_triang_masked_softmax import ScaledSoftmaxBuilder
 
-__all__ = ['CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder']
+__all__ = ['CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledSoftmaxBuilder']
```
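A small illustrative check (not ColossalAI code) that the re-export makes the new builder reachable at the package root alongside the existing three:

```python
from colossalai.kernel import op_builder

# Prints CPUAdamBuilder, FusedOptimBuilder, MultiHeadAttnBuilder, ScaledSoftmaxBuilder.
for name in op_builder.__all__:
    print(name)
```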
colossalai/kernel/op_builder/scaled_upper_triang_masked_softmax.py (new file, mode 100644)

```python
import os

from .builder import Builder, get_cuda_cc_flag


class ScaledSoftmaxBuilder(Builder):

    def __init__(self):
        self.base_dir = "cuda_native/csrc"
        self.name = 'scaled_upper_triang_masked_softmax'
        super().__init__()

    def include_dirs(self):
        ret = []
        ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_home_include()]
        ret.append(os.path.join(self.base_dir, "kernels", "include"))
        return [self.colossalai_src_path(path) for path in ret]

    def sources_files(self):
        ret = [
            os.path.join(self.base_dir, fname)
            for fname in ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu']
        ]
        return [self.colossalai_src_path(path) for path in ret]

    def cxx_flags(self):
        return ['-O3']

    def nvcc_flags(self):
        extra_cuda_flags = [
            '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr',
            '--expt-extended-lambda'
        ]
        extra_cuda_flags.extend(get_cuda_cc_flag())
        ret = ['-O3', '--use_fast_math'] + extra_cuda_flags
        return ret
```
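The `Builder` base class and `get_cuda_cc_flag` are not part of this diff, so the following is only a sketch of what a JIT `load()` built on these hooks typically does, assuming it delegates to PyTorch's real `torch.utils.cpp_extension.load`; it is an assumption about the base class, not ColossalAI's actual implementation.

```python
from torch.utils.cpp_extension import load


def jit_load(builder):
    # Assumed glue: hand the subclass hooks to PyTorch's JIT extension loader.
    return load(
        name=builder.name,                       # 'scaled_upper_triang_masked_softmax'
        sources=builder.sources_files(),         # the .cpp and .cu files listed above
        extra_include_paths=builder.include_dirs(),
        extra_cflags=builder.cxx_flags(),        # ['-O3']
        extra_cuda_cflags=builder.nvcc_flags(),  # -O3, fast math, half-precision and cc flags
        verbose=True,
    )
```

Splitting the recipe into `sources_files`, `include_dirs`, `cxx_flags`, and `nvcc_flags` lets one class drive both this JIT path and the ahead-of-time path that setup.py now uses.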
examples/language/gpt/train_gpt_demo.py

```diff
@@ -324,7 +324,7 @@ def main():
             if n >= WARMUP_STEPS:
                 tflops_list.append(step_tflops)
 
-    logger.info(f"max memory {torch.cuda.memory_allocated() / 1024**2} MB", ranks=[0])
+    logger.info(f"max memory {torch.cuda.max_memory_allocated() / 1024**2} MB", ranks=[0])
 
     tflops_list.sort()
     median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
```
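The one-word change is a genuine bug fix: `torch.cuda.memory_allocated()` reports the tensors currently alive, while `torch.cuda.max_memory_allocated()` reports the high-water mark since startup (or the last peak reset), which is what a "max memory" log line in a training benchmark should show. A standalone illustration:

```python
import torch

torch.cuda.reset_peak_memory_stats()

x = torch.empty(1024, 1024, device='cuda')            # ~4 MB of fp32 live
del x                                                  # freed again

print(torch.cuda.memory_allocated() / 1024**2)         # ~0 MB: nothing alive now
print(torch.cuda.max_memory_allocated() / 1024**2)     # ~4 MB: the peak is remembered
```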
setup.py

```diff
@@ -154,10 +154,8 @@ if build_cuda_ext:
         '--expt-extended-lambda'
     ]
 
-    ext_modules.append(
-        cuda_ext_helper('colossalai._C.scaled_upper_triang_masked_softmax',
-                        ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu'],
-                        extra_cuda_flags + cc_flag))
+    from colossalai.kernel.op_builder import ScaledSoftmaxBuilder
+    ext_modules.append(ScaledSoftmaxBuilder().builder('colossalai._C.scaled_upper_triang_masked_softmax'))
 
     ext_modules.append(
         cuda_ext_helper('colossalai._C.scaled_masked_softmax',
```
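`ScaledSoftmaxBuilder().builder(...)` replaces the hand-rolled `cuda_ext_helper` call, so the source list and compiler flags now live in one place. The diff does not show `builder()`'s body; a hedged sketch of what it plausibly returns, assuming it maps the same hooks onto PyTorch's real `CUDAExtension` for ahead-of-time builds (an assumption, not ColossalAI's actual code):

```python
from torch.utils.cpp_extension import CUDAExtension


def aot_extension(builder, full_name):
    # Assumed mapping from the builder hooks to a setuptools extension.
    return CUDAExtension(
        name=full_name,                          # 'colossalai._C.scaled_upper_triang_masked_softmax'
        sources=builder.sources_files(),
        include_dirs=builder.include_dirs(),
        extra_compile_args={
            'cxx': builder.cxx_flags(),
            'nvcc': builder.nvcc_flags(),
        },
    )
```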