OpenDAS / ColossalAI · commit 9587b080 (unverified)

Authored Dec 23, 2022 by Jiarui Fang; committed by GitHub on Dec 23, 2022.

[builder] use runtime builder for fused_optim (#2189)

Parent: ce3c4eca
Showing 5 changed files with 32 additions and 18 deletions (+32, -18):
colossalai/nn/optimizer/fused_adam.py           +6  -3
colossalai/nn/optimizer/fused_lamb.py           +8  -3
colossalai/nn/optimizer/fused_sgd.py            +6  -3
colossalai/utils/common.py                      +7  -7
tests/test_optimizer/test_fused_adam_kernel.py  +5  -2
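All five diffs below apply the same change: instead of importing the precompiled colossalai._C.fused_optim extension directly (and failing hard when it is missing), each call site now falls back to building the kernel at runtime with FusedOptimBuilder. The shared pattern, extracted from the hunks below (note the optimizer files import the builder from colossalai.kernel.op_builder.fused_optim, while the other files import it from the colossalai.kernel.op_builder package):

    # Prefer the prebuilt CUDA extension; otherwise JIT-compile and load
    # the fused_optim kernel with the runtime op builder.
    try:
        from colossalai._C import fused_optim
    except:
        from colossalai.kernel.op_builder import FusedOptimBuilder
        fused_optim = FusedOptimBuilder().load()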
colossalai/nn/optimizer/fused_adam.py @ 9587b080

@@ -65,11 +65,14 @@ class FusedAdam(torch.optim.Optimizer):
         self.adamw_mode = 1 if adamw_mode else 0
         self.set_grad_none = set_grad_none
         if multi_tensor_applier.available:
-            import colossalai._C.fused_optim
+            try:
+                from colossalai._C import fused_optim
+            except:
+                from colossalai.kernel.op_builder.fused_optim import FusedOptimBuilder
+                fused_optim = FusedOptimBuilder().load()
+
             # Skip buffer
             self._dummy_overflow_buf = torch.cuda.IntTensor([0])
-            self.multi_tensor_adam = colossalai._C.fused_optim.multi_tensor_adam
+            self.multi_tensor_adam = fused_optim.multi_tensor_adam
         else:
             raise RuntimeError('FusedAdam requires cuda extensions')
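With the fallback in place, FusedAdam can be constructed even when the prebuilt extension is absent, provided a CUDA toolchain is available for the first-use compile. A minimal usage sketch; the public import path and the hyperparameters here are illustrative assumptions, only adamw_mode comes from the hunk above:

    import torch
    from colossalai.nn.optimizer import FusedAdam   # assumed import path

    model = torch.nn.Linear(64, 64).cuda()          # any CUDA-resident module
    optimizer = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)

    loss = model(torch.randn(8, 64, device='cuda')).sum()
    loss.backward()
    optimizer.step()   # dispatches to fused_optim.multi_tensor_adam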
colossalai/nn/optimizer/fused_lamb.py @ 9587b080

@@ -76,13 +76,18 @@ class FusedLAMB(torch.optim.Optimizer):
                         max_grad_norm=max_grad_norm)
         super(FusedLAMB, self).__init__(params, defaults)
         if multi_tensor_applier.available:
-            import colossalai._C.fused_optim
-            self.multi_tensor_l2norm = colossalai._C.fused_optim.multi_tensor_l2norm
+            try:
+                from colossalai._C import fused_optim
+            except:
+                from colossalai.kernel.op_builder.fused_optim import FusedOptimBuilder
+                fused_optim = FusedOptimBuilder().load()
+
+            self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm

             # Skip buffer
             self._dummy_overflow_buf = torch.tensor([0],
                                                     dtype=torch.int,
                                                     device=self.param_groups[0]["params"][0].device)
-            self.multi_tensor_lamb = colossalai._C.fused_optim.multi_tensor_lamb
+            self.multi_tensor_lamb = fused_optim.multi_tensor_lamb
         else:
             raise RuntimeError('FusedLAMB requires cuda extensions')
colossalai/nn/optimizer/fused_sgd.py @ 9587b080

@@ -80,13 +80,16 @@ class FusedSGD(Optimizer):
         self.wd_after_momentum = wd_after_momentum

         if multi_tensor_applier.available:
-            import colossalai._C.fused_optim
+            try:
+                from colossalai._C import fused_optim
+            except:
+                from colossalai.kernel.op_builder import FusedOptimBuilder
+                fused_optim = FusedOptimBuilder().load()

             # Skip buffer
             self._dummy_overflow_buf = torch.tensor([0],
                                                     dtype=torch.int,
                                                     device=self.param_groups[0]["params"][0].device)
-            self.multi_tensor_sgd = colossalai._C.fused_optim.multi_tensor_sgd
+            self.multi_tensor_sgd = fused_optim.multi_tensor_sgd
         else:
             raise RuntimeError('FusedSGD requires cuda extensions')
colossalai/utils/common.py @ 9587b080

@@ -12,9 +12,10 @@ from torch._six import inf
 from torch.nn.parameter import Parameter

 try:
-    import colossalai._C.fused_optim
+    from colossalai._C import fused_optim
 except:
-    pass
+    from colossalai.kernel.op_builder import FusedOptimBuilder
+    fused_optim = FusedOptimBuilder().load()

 from collections import defaultdict
 from contextlib import contextmanager

@@ -133,7 +134,7 @@ def _calc_l2_norm(grads):
     if len(grads) > 0:
         dummy_overflow_buf = torch.cuda.IntTensor([0])
         norm, _ = multi_tensor_applier(
-            colossalai._C.fused_optim.multi_tensor_l2norm,
+            fused_optim.multi_tensor_l2norm,
             dummy_overflow_buf,
             [grads],
             False    # no per-parameter norm

@@ -270,8 +271,8 @@ def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
                 cpu_grads.append(p.grad.detach())
         if len(cuda_grads) > 0:
             dummy_overflow_buf = torch.cuda.IntTensor([0])
-            multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, dummy_overflow_buf,
+            multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf,
                                  [cuda_grads, cuda_grads], clip_coef)
         for g in cpu_grads:
             g.mul_(clip_coef)

@@ -397,8 +398,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     if enable_cuda_kernels:
         grads = [p.grad.detach() for p in params]
         dummy_overflow_buf = torch.cuda.IntTensor([0])
-        multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads],
-                             clip_coeff)
+        multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
     else:
         for p in params:
             p.grad.detach().mul_(clip_coeff)
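For reference, a sketch of how the runtime-loaded handle feeds multi_tensor_applier, mirroring the _calc_l2_norm hunk above. The colossalai.utils import path for multi_tensor_applier is an assumption; the call arguments follow the diff:

    import torch
    from colossalai.utils import multi_tensor_applier        # assumed path
    from colossalai.kernel.op_builder import FusedOptimBuilder

    fused_optim = FusedOptimBuilder().load()                 # runtime build
    grads = [torch.randn(1024, device='cuda') for _ in range(4)]
    dummy_overflow_buf = torch.cuda.IntTensor([0])
    # Global L2 norm over all listed tensors; False disables the
    # per-parameter norms, matching the hunk above.
    norm, _ = multi_tensor_applier(fused_optim.multi_tensor_l2norm,
                                   dummy_overflow_buf, [grads], False)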
tests/test_optimizer/test_fused_adam_kernel.py @ 9587b080

@@ -49,9 +49,12 @@ def test_adam(adamw, step, p_dtype, g_dtype):
     try:
         import colossalai._C.fused_optim
         fused_adam = colossalai._C.fused_optim.multi_tensor_adam
-        dummy_overflow_buf = torch.cuda.IntTensor([0])
     except:
-        raise ImportError("No colossalai._C.fused_optim kernel installed.")
+        from colossalai.kernel.op_builder import FusedOptimBuilder
+        fused_optim = FusedOptimBuilder().load()
+        fused_adam = fused_optim.multi_tensor_adam
+
+    dummy_overflow_buf = torch.cuda.IntTensor([0])

     count = 0
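After this change the test no longer raises ImportError when the prebuilt extension is missing; it builds the kernel at runtime instead (a working CUDA toolchain is still assumed). One way to run just this file, a sketch using stock pytest:

    import pytest

    # Equivalent to `pytest tests/test_optimizer/test_fused_adam_kernel.py -v`
    # from the repository root.
    pytest.main(["tests/test_optimizer/test_fused_adam_kernel.py", "-v"])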